classifier 2.0.0 → 2.2.0

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
@@ -6,61 +6,110 @@
6
6
 
7
7
  module Classifier
8
8
  class LSI
9
- # @rbs @gsl_available: bool
10
- @gsl_available = false
9
+ # Backend options: :native, :ruby
10
+ # @rbs @backend: Symbol
11
+ @backend = :ruby
11
12
 
12
13
  class << self
13
- # @rbs @gsl_available: bool
14
- attr_accessor :gsl_available
14
+ # @rbs @backend: Symbol
15
+ attr_accessor :backend
16
+
17
+ # Check if using native C extension
18
+ # @rbs () -> bool
19
+ def native_available?
20
+ backend == :native
21
+ end
22
+
23
+ # Get the Vector class for the current backend
24
+ # @rbs () -> Class
25
+ def vector_class
26
+ backend == :native ? Classifier::Linalg::Vector : ::Vector
27
+ end
28
+
29
+ # Get the Matrix class for the current backend
30
+ # @rbs () -> Class
31
+ def matrix_class
32
+ backend == :native ? Classifier::Linalg::Matrix : ::Matrix
33
+ end
15
34
  end
16
35
  end
17
36
  end
18
37
 
38
+ # Backend detection: native extension > pure Ruby
39
+ # Set NATIVE_VECTOR=true to force pure Ruby implementation
40
+
19
41
  begin
20
- # to test the native vector class, try `rake test NATIVE_VECTOR=true`
21
42
  raise LoadError if ENV['NATIVE_VECTOR'] == 'true'
22
- raise LoadError unless Gem::Specification.find_all_by_name('gsl').any?
23
43
 
24
- require 'gsl'
25
- require 'classifier/extensions/vector_serialize'
26
- Classifier::LSI.gsl_available = true
44
+ require 'classifier/classifier_ext'
45
+ Classifier::LSI.backend = :native
27
46
  rescue LoadError
28
- unless ENV['SUPPRESS_GSL_WARNING'] == 'true'
29
- warn 'Notice: for 10x faster LSI, run `gem install gsl`. Set SUPPRESS_GSL_WARNING=true to hide this.'
47
+ # Fall back to pure Ruby implementation
48
+ unless ENV['SUPPRESS_LSI_WARNING'] == 'true'
49
+ warn 'Notice: for 5-10x faster LSI, install the classifier gem with native extensions. ' \
50
+ 'Set SUPPRESS_LSI_WARNING=true to hide this.'
30
51
  end
31
- Classifier::LSI.gsl_available = false
52
+ Classifier::LSI.backend = :ruby
32
53
  require 'classifier/extensions/vector'
33
54
  end
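For orientation, a minimal sketch (not part of the package source) of how the backend switch set up above can be inspected once the gem is loaded:

    require 'classifier'

    Classifier::LSI.backend            # => :native if classifier_ext loaded, otherwise :ruby
    Classifier::LSI.native_available?  # => true only when the C extension backend is active
    Classifier::LSI.vector_class       # => Classifier::Linalg::Vector or ::Vector
    Classifier::LSI.matrix_class       # => Classifier::Linalg::Matrix or ::Matrix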
34
55
 
56
+ require 'json'
57
+ require 'mutex_m'
35
58
  require 'classifier/lsi/word_list'
36
59
  require 'classifier/lsi/content_node'
37
60
  require 'classifier/lsi/summary'
61
+ require 'classifier/lsi/incremental_svd'
38
62
 
39
63
  module Classifier
40
64
  # This class implements a Latent Semantic Indexer, which can search, classify and cluster
41
65
  # data based on underlying semantic relations. For more information on the algorithms used,
42
66
  # please consult Wikipedia[http://en.wikipedia.org/wiki/Latent_Semantic_Indexing].
43
67
  class LSI
68
+ include Mutex_m
69
+ include Streaming
70
+
44
71
  # @rbs @auto_rebuild: bool
45
72
  # @rbs @word_list: WordList
46
73
  # @rbs @items: Hash[untyped, ContentNode]
47
74
  # @rbs @version: Integer
48
75
  # @rbs @built_at_version: Integer
76
+ # @rbs @singular_values: Array[Float]?
77
+ # @rbs @dirty: bool
78
+ # @rbs @storage: Storage::Base?
79
+ # @rbs @incremental_mode: bool
80
+ # @rbs @u_matrix: Matrix?
81
+ # @rbs @max_rank: Integer
82
+ # @rbs @initial_vocab_size: Integer?
49
83
 
50
- attr_reader :word_list
51
- attr_accessor :auto_rebuild
84
+ attr_reader :word_list, :singular_values
85
+ attr_accessor :auto_rebuild, :storage
86
+
87
+ # Default maximum rank for incremental SVD
88
+ DEFAULT_MAX_RANK = 100
52
89
 
53
90
  # Create a fresh index.
54
91
  # If you want to call #build_index manually, use
55
92
  # Classifier::LSI.new auto_rebuild: false
56
93
  #
94
+ # For incremental SVD mode (adds documents without full rebuild):
95
+ # Classifier::LSI.new incremental: true, max_rank: 100
96
+ #
57
97
  # @rbs (?Hash[Symbol, untyped]) -> void
58
98
  def initialize(options = {})
99
+ super()
59
100
  @auto_rebuild = true unless options[:auto_rebuild] == false
60
101
  @word_list = WordList.new
61
102
  @items = {}
62
103
  @version = 0
63
104
  @built_at_version = -1
105
+ @dirty = false
106
+ @storage = nil
107
+
108
+ # Incremental SVD settings
109
+ @incremental_mode = options[:incremental] == true
110
+ @max_rank = options[:max_rank] || DEFAULT_MAX_RANK
111
+ @u_matrix = nil
112
+ @initial_vocab_size = nil
64
113
  end
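Since the class now includes Mutex_m and wraps its public entry points in synchronize, concurrent use is intended to be safe; a small sketch under that assumption (documents is a hypothetical array of strings):

    lsi = Classifier::LSI.new(auto_rebuild: false)
    documents.map { |text| Thread.new { lsi.add_item(text, :uncategorized) } }.each(&:join)
    lsi.build_index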
65
114
 
66
115
  # Returns true if the index needs to be rebuilt. The index needs
@@ -69,7 +118,85 @@ module Classifier
69
118
  #
70
119
  # @rbs () -> bool
71
120
  def needs_rebuild?
72
- (@items.keys.size > 1) && (@version != @built_at_version)
121
+ synchronize { (@items.keys.size > 1) && (@version != @built_at_version) }
122
+ end
123
+
124
+ # @rbs () -> Array[Hash[Symbol, untyped]]?
125
+ def singular_value_spectrum
126
+ return nil unless @singular_values
127
+
128
+ total = @singular_values.sum
129
+ return nil if total.zero?
130
+
131
+ cumulative = 0.0
132
+ @singular_values.map.with_index do |value, i|
133
+ cumulative += value
134
+ {
135
+ dimension: i,
136
+ value: value,
137
+ percentage: value / total,
138
+ cumulative_percentage: cumulative / total
139
+ }
140
+ end
141
+ end
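The spectrum is only populated after a build; an illustrative return value (numbers invented for the example):

    lsi.singular_value_spectrum
    # => [{ dimension: 0, value: 9.8, percentage: 0.42, cumulative_percentage: 0.42 },
    #     { dimension: 1, value: 6.1, percentage: 0.26, cumulative_percentage: 0.68 }, ...]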
142
+
143
+ # Returns true if incremental mode is enabled and active.
144
+ # Incremental mode becomes active after the first build_index call.
145
+ #
146
+ # @rbs () -> bool
147
+ def incremental_enabled?
148
+ @incremental_mode && !@u_matrix.nil?
149
+ end
150
+
151
+ # Returns the current rank of the incremental SVD (number of singular values kept).
152
+ # Returns nil if incremental mode is not active.
153
+ #
154
+ # @rbs () -> Integer?
155
+ def current_rank
156
+ @singular_values&.count(&:positive?)
157
+ end
158
+
159
+ # Disables incremental mode. Subsequent adds will trigger full rebuilds.
160
+ #
161
+ # @rbs () -> void
162
+ def disable_incremental_mode!
163
+ @incremental_mode = false
164
+ @u_matrix = nil
165
+ @initial_vocab_size = nil
166
+ end
167
+
168
+ # Enables incremental mode with optional max_rank setting.
169
+ # The next build_index call will store the U matrix for incremental updates.
170
+ #
171
+ # @rbs (?max_rank: Integer) -> void
172
+ def enable_incremental_mode!(max_rank: DEFAULT_MAX_RANK)
173
+ @incremental_mode = true
174
+ @max_rank = max_rank
175
+ end
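A brief usage sketch for the incremental-SVD switches above (documents and ranks are illustrative, not taken from the package docs):

    lsi = Classifier::LSI.new(incremental: true, max_rank: 50)
    lsi.add('Dog' => 'Dogs are loyal pets', 'Cat' => 'Cats are independent')
    lsi.incremental_enabled?            # => true once the first build has stored the U matrix
    lsi.add('Bird' => 'Birds can fly')  # folded in incrementally while vocabulary growth stays under 20%
    lsi.current_rank                    # => number of positive singular values currently kept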
176
+
177
+ # Adds items to the index using hash-style syntax.
178
+ # The hash keys are categories, and values are items (or arrays of items).
179
+ #
180
+ # For example:
181
+ # lsi = Classifier::LSI.new
182
+ # lsi.add("Dog" => "Dogs are loyal pets")
183
+ # lsi.add("Cat" => "Cats are independent")
184
+ # lsi.add(Bird: "Birds can fly") # Symbol keys work too
185
+ #
186
+ # Multiple items with the same category:
187
+ # lsi.add("Dog" => ["Dogs are loyal", "Puppies are cute"])
188
+ #
189
+ # Batch operations with multiple categories:
190
+ # lsi.add(
191
+ # "Dog" => ["Dogs are loyal", "Puppies are cute"],
192
+ # "Cat" => ["Cats are independent", "Kittens are playful"]
193
+ # )
194
+ #
195
+ # @rbs (**untyped items) -> void
196
+ def add(**items)
197
+ items.each do |category, value|
198
+ Array(value).each { |doc| add_item(doc, category.to_s) }
199
+ end
73
200
  end
74
201
 
75
202
  # Adds an item to the index. item is assumed to be a string, but
@@ -78,6 +205,8 @@ module Classifier
78
205
  # fetch fresh string data. This optional block is passed the item,
79
206
  # so the item may only be a reference to a URL or file name.
80
207
  #
208
+ # @deprecated Use {#add} instead for clearer hash-style syntax.
209
+ #
81
210
  # For example:
82
211
  # lsi = Classifier::LSI.new
83
212
  # lsi.add_item "This is just plain text"
@@ -88,8 +217,18 @@ module Classifier
88
217
  # @rbs (String, *String | Symbol) ?{ (String) -> String } -> void
89
218
  def add_item(item, *categories, &block)
90
219
  clean_word_hash = block ? block.call(item).clean_word_hash : item.to_s.clean_word_hash
91
- @items[item] = ContentNode.new(clean_word_hash, *categories)
92
- @version += 1
220
+ node = nil
221
+
222
+ synchronize do
223
+ node = ContentNode.new(clean_word_hash, *categories)
224
+ @items[item] = node
225
+ @version += 1
226
+ @dirty = true
227
+ end
228
+
229
+ # Use incremental update if enabled and we have a U matrix
230
+ return perform_incremental_update(node, clean_word_hash) if @incremental_mode && @u_matrix
231
+
93
232
  build_index if @auto_rebuild
94
233
  end
95
234
 
@@ -107,25 +246,32 @@ module Classifier
107
246
  #
108
247
  # @rbs (String) -> Array[String | Symbol]
109
248
  def categories_for(item)
110
- return [] unless @items[item]
249
+ synchronize do
250
+ return [] unless @items[item]
111
251
 
112
- @items[item].categories
252
+ @items[item].categories
253
+ end
113
254
  end
114
255
 
115
256
  # Removes an item from the database, if it is indexed.
116
257
  #
117
258
  # @rbs (String) -> void
118
259
  def remove_item(item)
119
- return unless @items.key?(item)
260
+ removed = synchronize do
261
+ next false unless @items.key?(item)
120
262
 
121
- @items.delete(item)
122
- @version += 1
263
+ @items.delete(item)
264
+ @version += 1
265
+ @dirty = true
266
+ true
267
+ end
268
+ build_index if removed && @auto_rebuild
123
269
  end
124
270
 
125
271
  # Returns an array of items that are indexed.
126
272
  # @rbs () -> Array[untyped]
127
273
  def items
128
- @items.keys
274
+ synchronize { @items.keys }
129
275
  end
130
276
 
131
277
  # This function rebuilds the index if needs_rebuild? returns true.
@@ -143,40 +289,38 @@ module Classifier
143
289
  # A value of 1 for cutoff means that no semantic analysis will take place,
144
290
  # turning the LSI class into a simple vector search engine.
145
291
  #
146
- # @rbs (?Float) -> void
147
- def build_index(cutoff = 0.75)
148
- return unless needs_rebuild?
292
+ # @rbs (?Float, ?force: bool) -> void
293
+ def build_index(cutoff = 0.75, force: false)
294
+ validate_cutoff!(cutoff)
149
295
 
150
- make_word_list
296
+ synchronize do
297
+ return unless force || needs_rebuild_unlocked?
151
298
 
152
- doc_list = @items.values
153
- tda = doc_list.collect { |node| node.raw_vector_with(@word_list) }
299
+ make_word_list
154
300
 
155
- if self.class.gsl_available
156
- tdm = GSL::Matrix.alloc(*tda).trans
157
- ntdm = build_reduced_matrix(tdm, cutoff)
301
+ doc_list = @items.values
302
+ tda = doc_list.collect { |node| node.raw_vector_with(@word_list) }
158
303
 
159
- ntdm.size[1].times do |col|
160
- vec = GSL::Vector.alloc(ntdm.column(col)).row
161
- doc_list[col].lsi_vector = vec
162
- doc_list[col].lsi_norm = vec.normalize
304
+ if self.class.native_available?
305
+ # Convert vectors to arrays for matrix construction
306
+ tda_arrays = tda.map { |v| v.respond_to?(:to_a) ? v.to_a : v }
307
+ tdm = self.class.matrix_class.alloc(*tda_arrays).trans
308
+ ntdm, u_mat = build_reduced_matrix_with_u(tdm, cutoff)
309
+ assign_native_ext_lsi_vectors(ntdm, doc_list)
310
+ else
311
+ tdm = Matrix.rows(tda).trans
312
+ ntdm, u_mat = build_reduced_matrix_with_u(tdm, cutoff)
313
+ assign_ruby_lsi_vectors(ntdm, doc_list)
163
314
  end
164
- else
165
- tdm = Matrix.rows(tda).trans
166
- ntdm = build_reduced_matrix(tdm, cutoff)
167
315
 
168
- ntdm.column_size.times do |col|
169
- next unless doc_list[col]
170
-
171
- column = ntdm.column(col)
172
- next unless column
173
-
174
- doc_list[col].lsi_vector = column
175
- doc_list[col].lsi_norm = column.normalize
316
+ # Store U matrix for incremental mode
317
+ if @incremental_mode
318
+ @u_matrix = u_mat
319
+ @initial_vocab_size = @word_list.size
176
320
  end
177
- end
178
321
 
179
- @built_at_version = @version
322
+ @built_at_version = @version
323
+ end
180
324
  end
181
325
 
182
326
  # This method returns max_chunks entries, ordered by their average semantic rating.
@@ -190,12 +334,14 @@ module Classifier
190
334
  #
191
335
  # @rbs (?Integer) -> Array[String]
192
336
  def highest_relative_content(max_chunks = 10)
193
- return [] if needs_rebuild?
337
+ synchronize do
338
+ return [] if needs_rebuild_unlocked?
194
339
 
195
- avg_density = {}
196
- @items.each_key { |x| avg_density[x] = proximity_array_for_content(x).sum { |pair| pair[1] } }
340
+ avg_density = {}
341
+ @items.each_key { |x| avg_density[x] = proximity_array_for_content_unlocked(x).sum { |pair| pair[1] } }
197
342
 
198
- avg_density.keys.sort_by { |x| avg_density[x] }.reverse[0..(max_chunks - 1)].map
343
+ avg_density.keys.sort_by { |x| avg_density[x] }.reverse[0..(max_chunks - 1)].map
344
+ end
199
345
  end
200
346
 
201
347
  # This function is the primitive that find_related and classify
@@ -212,20 +358,8 @@ module Classifier
212
358
  # text data. See add_item for examples of how this works.
213
359
  #
214
360
  # @rbs (String) ?{ (String) -> String } -> Array[[String, Float]]
215
- def proximity_array_for_content(doc, &)
216
- return [] if needs_rebuild?
217
-
218
- content_node = node_for_content(doc, &)
219
- result =
220
- @items.keys.collect do |item|
221
- val = if self.class.gsl_available
222
- content_node.search_vector * @items[item].search_vector.col
223
- else
224
- (Matrix[content_node.search_vector] * @items[item].search_vector)[0]
225
- end
226
- [item, val]
227
- end
228
- result.sort_by { |x| x[1] }.reverse
361
+ def proximity_array_for_content(doc, &block)
362
+ synchronize { proximity_array_for_content_unlocked(doc, &block) }
229
363
  end
230
364
 
231
365
  # Similar to proximity_array_for_content, this function takes similar
@@ -235,20 +369,8 @@ module Classifier
235
369
  # the text you're working with. search uses this primitive.
236
370
  #
237
371
  # @rbs (String) ?{ (String) -> String } -> Array[[String, Float]]
238
- def proximity_norms_for_content(doc, &)
239
- return [] if needs_rebuild?
240
-
241
- content_node = node_for_content(doc, &)
242
- result =
243
- @items.keys.collect do |item|
244
- val = if self.class.gsl_available
245
- content_node.search_norm * @items[item].search_norm.col
246
- else
247
- (Matrix[content_node.search_norm] * @items[item].search_norm)[0]
248
- end
249
- [item, val]
250
- end
251
- result.sort_by { |x| x[1] }.reverse
372
+ def proximity_norms_for_content(doc, &block)
373
+ synchronize { proximity_norms_for_content_unlocked(doc, &block) }
252
374
  end
253
375
 
254
376
  # This function allows for text-based search of your index. Unlike other functions
@@ -261,11 +383,13 @@ module Classifier
261
383
  #
262
384
  # @rbs (String, ?Integer) -> Array[String]
263
385
  def search(string, max_nearest = 3)
264
- return [] if needs_rebuild?
386
+ synchronize do
387
+ return [] if needs_rebuild_unlocked?
265
388
 
266
- carry = proximity_norms_for_content(string)
267
- result = carry.collect { |x| x[0] }
268
- result[0..(max_nearest - 1)]
389
+ carry = proximity_norms_for_content_unlocked(string)
390
+ result = carry.collect { |x| x[0] }
391
+ result[0..(max_nearest - 1)]
392
+ end
269
393
  end
270
394
 
271
395
  # This function takes content and finds other documents
@@ -280,10 +404,12 @@ module Classifier
280
404
  #
281
405
  # @rbs (String, ?Integer) ?{ (String) -> String } -> Array[String]
282
406
  def find_related(doc, max_nearest = 3, &block)
283
- carry =
284
- proximity_array_for_content(doc, &block).reject { |pair| pair[0] == doc }
285
- result = carry.collect { |x| x[0] }
286
- result[0..(max_nearest - 1)]
407
+ synchronize do
408
+ carry =
409
+ proximity_array_for_content_unlocked(doc, &block).reject { |pair| pair[0] == doc }
410
+ result = carry.collect { |x| x[0] }
411
+ result[0..(max_nearest - 1)]
412
+ end
287
413
  end
288
414
 
289
415
  # This function uses a voting system to categorize documents, based on
@@ -291,32 +417,23 @@ module Classifier
291
417
  # find_related function to find related documents, then returns the
292
418
  # most obvious category from this list.
293
419
  #
294
- # cutoff signifies the number of documents to consider when clasifying
295
- # text. A cutoff of 1 means that every document in the index votes on
296
- # what category the document is in. This may not always make sense.
297
- #
298
420
  # @rbs (String, ?Float) ?{ (String) -> String } -> String | Symbol
299
- def classify(doc, cutoff = 0.30, &)
300
- votes = vote(doc, cutoff, &)
421
+ def classify(doc, cutoff = 0.30, &block)
422
+ validate_cutoff!(cutoff)
423
+
424
+ synchronize do
425
+ votes = vote_unlocked(doc, cutoff, &block)
301
426
 
302
- ranking = votes.keys.sort_by { |x| votes[x] }
303
- ranking[-1]
427
+ ranking = votes.keys.sort_by { |x| votes[x] }
428
+ ranking[-1]
429
+ end
304
430
  end
305
431
 
306
432
  # @rbs (String, ?Float) ?{ (String) -> String } -> Hash[String | Symbol, Float]
307
- def vote(doc, cutoff = 0.30, &)
308
- icutoff = (@items.size * cutoff).round
309
- carry = proximity_array_for_content(doc, &)
310
- carry = carry[0..(icutoff - 1)]
311
- votes = {}
312
- carry.each do |pair|
313
- categories = @items[pair[0]].categories
314
- categories.each do |category|
315
- votes[category] ||= 0.0
316
- votes[category] += pair[1]
317
- end
318
- end
319
- votes
433
+ def vote(doc, cutoff = 0.30, &block)
434
+ validate_cutoff!(cutoff)
435
+
436
+ synchronize { vote_unlocked(doc, cutoff, &block) }
320
437
  end
321
438
 
322
439
  # Returns the same category as classify() but also returns
@@ -331,15 +448,19 @@ module Classifier
331
448
  #
332
449
  # See classify() for argument docs
333
450
  # @rbs (String, ?Float) ?{ (String) -> String } -> [String | Symbol | nil, Float?]
334
- def classify_with_confidence(doc, cutoff = 0.30, &)
335
- votes = vote(doc, cutoff, &)
336
- votes_sum = votes.values.sum
337
- return [nil, nil] if votes_sum.zero?
451
+ def classify_with_confidence(doc, cutoff = 0.30, &block)
452
+ validate_cutoff!(cutoff)
338
453
 
339
- ranking = votes.keys.sort_by { |x| votes[x] }
340
- winner = ranking[-1]
341
- vote_share = votes[winner] / votes_sum.to_f
342
- [winner, vote_share]
454
+ synchronize do
455
+ votes = vote_unlocked(doc, cutoff, &block)
456
+ votes_sum = votes.values.sum
457
+ return [nil, nil] if votes_sum.zero?
458
+
459
+ ranking = votes.keys.sort_by { |x| votes[x] }
460
+ winner = ranking[-1]
461
+ vote_share = votes[winner] / votes_sum.to_f
462
+ [winner, vote_share]
463
+ end
343
464
  end
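An illustrative call showing the [category, vote_share] pair described above (the confidence value is hypothetical):

    category, confidence = lsi.classify_with_confidence('Dogs make great family pets')
    # => ['Dog', 0.72]   # confidence is the winner's share of the total vote weight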
344
465
 
345
466
  # Prototype, only works on indexed documents.
@@ -347,45 +468,446 @@ module Classifier
347
468
  # it's supposed to.
348
469
  # @rbs (String, ?Integer) -> Array[Symbol]
349
470
  def highest_ranked_stems(doc, count = 3)
350
- raise 'Requested stem ranking on non-indexed content!' unless @items[doc]
471
+ synchronize do
472
+ raise 'Requested stem ranking on non-indexed content!' unless @items[doc]
351
473
 
352
- arr = node_for_content(doc).lsi_vector.to_a
353
- top_n = arr.sort.reverse[0..(count - 1)]
354
- top_n.collect { |x| @word_list.word_for_index(arr.index(x)) }
474
+ arr = node_for_content_unlocked(doc).lsi_vector.to_a
475
+ top_n = arr.sort.reverse[0..(count - 1)]
476
+ top_n.collect { |x| @word_list.word_for_index(arr.index(x)) }
477
+ end
355
478
  end
356
479
 
357
- private
480
+ # Custom marshal serialization to exclude mutex state
481
+ # @rbs () -> Array[untyped]
482
+ def marshal_dump
483
+ [@auto_rebuild, @word_list, @items, @version, @built_at_version, @dirty]
484
+ end
358
485
 
359
- # @rbs (untyped, ?Float) -> untyped
360
- def build_reduced_matrix(matrix, cutoff = 0.75)
361
- # TODO: Check that M>=N on these dimensions! Transpose helps assure this
362
- u, v, s = matrix.SV_decomp
486
+ # Custom marshal deserialization to recreate mutex
487
+ # @rbs (Array[untyped]) -> void
488
+ def marshal_load(data)
489
+ mu_initialize
490
+ @auto_rebuild, @word_list, @items, @version, @built_at_version, @dirty = data
491
+ @storage = nil
492
+ end
363
493
 
364
- # TODO: Better than 75% term, please. :\
365
- s_cutoff = s.sort.reverse[(s.size * cutoff).round - 1]
366
- s.size.times do |ord|
367
- s[ord] = 0.0 if s[ord] < s_cutoff
494
+ # Returns a hash representation of the LSI index.
495
+ # Only source data (word_hash, categories) is included, not computed vectors.
496
+ # This can be converted to JSON or used directly.
497
+ #
498
+ # @rbs () -> untyped
499
+ def as_json(*)
500
+ items_data = @items.transform_values do |node|
501
+ {
502
+ word_hash: node.word_hash.transform_keys(&:to_s),
503
+ categories: node.categories.map(&:to_s)
504
+ }
505
+ end
506
+
507
+ {
508
+ version: 1,
509
+ type: 'lsi',
510
+ auto_rebuild: @auto_rebuild,
511
+ items: items_data
512
+ }
513
+ end
514
+
515
+ # Serializes the LSI index to a JSON string.
516
+ # Only source data (word_hash, categories) is serialized, not computed vectors.
517
+ # On load, the index will be rebuilt automatically.
518
+ #
519
+ # @rbs () -> String
520
+ def to_json(*)
521
+ as_json.to_json
522
+ end
523
+
524
+ # Loads an LSI index from a JSON string or Hash created by #to_json or #as_json.
525
+ # The index will be rebuilt after loading.
526
+ #
527
+ # @rbs (String | Hash[String, untyped]) -> LSI
528
+ def self.from_json(json)
529
+ data = json.is_a?(String) ? JSON.parse(json) : json
530
+ raise ArgumentError, "Invalid classifier type: #{data['type']}" unless data['type'] == 'lsi'
531
+
532
+ # Create instance with auto_rebuild disabled during loading
533
+ instance = new(auto_rebuild: false)
534
+
535
+ # Restore items (categories stay as strings, matching original storage)
536
+ data['items'].each do |item_key, item_data|
537
+ word_hash = item_data['word_hash'].transform_keys(&:to_sym)
538
+ categories = item_data['categories']
539
+ instance.instance_variable_get(:@items)[item_key] = ContentNode.new(word_hash, *categories)
540
+ instance.instance_variable_set(:@version, instance.instance_variable_get(:@version) + 1)
368
541
  end
369
- # Reconstruct the term document matrix, only with reduced rank
370
- result = u * (self.class.gsl_available ? GSL::Matrix : ::Matrix).diag(s) * v.trans
371
542
 
372
- # Native Ruby SVD returns transposed dimensions when row_size < column_size
373
- # Ensure result matches input dimensions
374
- result = result.trans if !self.class.gsl_available && result.row_size != matrix.row_size
543
+ # Restore auto_rebuild setting and rebuild index
544
+ instance.auto_rebuild = data['auto_rebuild']
545
+ instance.build_index
546
+ instance
547
+ end
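A round-trip sketch for the JSON serialization above; only source data is saved, so from_json rebuilds the index while loading:

    json = lsi.to_json
    restored = Classifier::LSI.from_json(json)
    restored.items == lsi.items   # => true when the original items were plain strings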
548
+
549
+ # Saves the LSI index to the configured storage.
550
+ # Raises ArgumentError if no storage is configured.
551
+ #
552
+ # @rbs () -> void
553
+ def save
554
+ raise ArgumentError, 'No storage configured. Use save_to_file(path) or set storage=' unless storage
555
+
556
+ storage.write(to_json)
557
+ @dirty = false
558
+ end
375
559
 
560
+ # Saves the LSI index to a file (legacy API).
561
+ #
562
+ # @rbs (String) -> Integer
563
+ def save_to_file(path)
564
+ result = File.write(path, to_json)
565
+ @dirty = false
376
566
  result
377
567
  end
378
568
 
569
+ # Reloads the LSI index from the configured storage.
570
+ # Raises UnsavedChangesError if there are unsaved changes.
571
+ # Use reload! to force reload and discard changes.
572
+ #
573
+ # @rbs () -> self
574
+ def reload
575
+ raise ArgumentError, 'No storage configured' unless storage
576
+ raise UnsavedChangesError, 'Unsaved changes would be lost. Call save first or use reload!' if @dirty
577
+
578
+ data = storage.read
579
+ raise StorageError, 'No saved state found' unless data
580
+
581
+ restore_from_json(data)
582
+ @dirty = false
583
+ self
584
+ end
585
+
586
+ # Force reloads the LSI index from storage, discarding any unsaved changes.
587
+ #
588
+ # @rbs () -> self
589
+ def reload!
590
+ raise ArgumentError, 'No storage configured' unless storage
591
+
592
+ data = storage.read
593
+ raise StorageError, 'No saved state found' unless data
594
+
595
+ restore_from_json(data)
596
+ @dirty = false
597
+ self
598
+ end
599
+
600
+ # Returns true if there are unsaved changes.
601
+ #
602
+ # @rbs () -> bool
603
+ def dirty?
604
+ @dirty
605
+ end
606
+
607
+ # Loads an LSI index from the configured storage.
608
+ # The storage is set on the returned instance.
609
+ #
610
+ # @rbs (storage: Storage::Base) -> LSI
611
+ def self.load(storage:)
612
+ data = storage.read
613
+ raise StorageError, 'No saved state found' unless data
614
+
615
+ instance = from_json(data)
616
+ instance.storage = storage
617
+ instance
618
+ end
619
+
620
+ # Loads an LSI index from a file (legacy API).
621
+ #
622
+ # @rbs (String) -> LSI
623
+ def self.load_from_file(path)
624
+ from_json(File.read(path))
625
+ end
626
+
627
+ # Loads an LSI index from a checkpoint.
628
+ #
629
+ # @rbs (storage: Storage::Base, checkpoint_id: String) -> LSI
630
+ def self.load_checkpoint(storage:, checkpoint_id:)
631
+ raise ArgumentError, 'Storage must be File storage for checkpoints' unless storage.is_a?(Storage::File)
632
+
633
+ dir = File.dirname(storage.path)
634
+ base = File.basename(storage.path, '.*')
635
+ ext = File.extname(storage.path)
636
+ checkpoint_path = File.join(dir, "#{base}_checkpoint_#{checkpoint_id}#{ext}")
637
+
638
+ checkpoint_storage = Storage::File.new(path: checkpoint_path)
639
+ instance = load(storage: checkpoint_storage)
640
+ instance.storage = storage
641
+ instance
642
+ end
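A usage sketch for the storage-backed persistence above, assuming the file-backed storage class referenced in load_checkpoint:

    storage = Classifier::Storage::File.new(path: 'lsi_index.json')
    lsi.storage = storage
    lsi.save                      # ArgumentError if no storage is configured
    lsi.dirty?                    # => false right after a save
    restored = Classifier::LSI.load(storage: storage)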
643
+
644
+ # Trains the LSI index from an IO stream.
645
+ # Each line in the stream is treated as a separate document.
646
+ # Documents are added without rebuilding, then the index is rebuilt at the end.
647
+ #
648
+ # @example Train from a file
649
+ # lsi.train_from_stream(:category, File.open('corpus.txt'))
650
+ #
651
+ # @example With progress tracking
652
+ # lsi.train_from_stream(:category, io, batch_size: 500) do |progress|
653
+ # puts "#{progress.completed} documents processed"
654
+ # end
655
+ #
656
+ # @rbs (String | Symbol, IO, ?batch_size: Integer) { (Streaming::Progress) -> void } -> void
657
+ def train_from_stream(category, io, batch_size: Streaming::DEFAULT_BATCH_SIZE)
658
+ original_auto_rebuild = @auto_rebuild
659
+ @auto_rebuild = false
660
+
661
+ begin
662
+ reader = Streaming::LineReader.new(io, batch_size: batch_size)
663
+ total = reader.estimate_line_count
664
+ progress = Streaming::Progress.new(total: total)
665
+
666
+ reader.each_batch do |batch|
667
+ batch.each { |text| add_item(text, category) }
668
+ progress.completed += batch.size
669
+ progress.current_batch += 1
670
+ yield progress if block_given?
671
+ end
672
+ ensure
673
+ @auto_rebuild = original_auto_rebuild
674
+ build_index if original_auto_rebuild
675
+ end
676
+ end
677
+
678
+ # Adds items to the index in batches from an array.
679
+ # Documents are added without rebuilding, then the index is rebuilt at the end.
680
+ #
681
+ # @example Batch add with progress
682
+ # lsi.add_batch(Dog: documents, batch_size: 100) do |progress|
683
+ # puts "#{progress.percent}% complete"
684
+ # end
685
+ #
686
+ # @rbs (?batch_size: Integer, **Array[String]) { (Streaming::Progress) -> void } -> void
687
+ def add_batch(batch_size: Streaming::DEFAULT_BATCH_SIZE, **items)
688
+ original_auto_rebuild = @auto_rebuild
689
+ @auto_rebuild = false
690
+
691
+ begin
692
+ total_docs = items.values.sum { |v| Array(v).size }
693
+ progress = Streaming::Progress.new(total: total_docs)
694
+
695
+ items.each do |category, documents|
696
+ Array(documents).each_slice(batch_size) do |batch|
697
+ batch.each { |doc| add_item(doc, category.to_s) }
698
+ progress.completed += batch.size
699
+ progress.current_batch += 1
700
+ yield progress if block_given?
701
+ end
702
+ end
703
+ ensure
704
+ @auto_rebuild = original_auto_rebuild
705
+ build_index if original_auto_rebuild
706
+ end
707
+ end
708
+
709
+ # Alias train_batch to add_batch for API consistency with other classifiers.
710
+ # Note: LSI uses categories differently (items have categories, not the training call).
711
+ #
712
+ # @rbs (?(String | Symbol)?, ?Array[String]?, ?batch_size: Integer, **Array[String]) { (Streaming::Progress) -> void } -> void
713
+ def train_batch(category = nil, documents = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &block)
714
+ if category && documents
715
+ add_batch(batch_size: batch_size, **{ category.to_sym => documents }, &block)
716
+ else
717
+ add_batch(batch_size: batch_size, **categories, &block)
718
+ end
719
+ end
720
+
721
+ private
722
+
723
+ # Restores LSI state from a JSON string (used by reload)
724
+ # @rbs (String) -> void
725
+ def restore_from_json(json)
726
+ data = JSON.parse(json)
727
+ raise ArgumentError, "Invalid classifier type: #{data['type']}" unless data['type'] == 'lsi'
728
+
729
+ synchronize do
730
+ # Recreate the items
731
+ @items = {}
732
+ data['items'].each do |item_key, item_data|
733
+ word_hash = item_data['word_hash'].transform_keys(&:to_sym)
734
+ categories = item_data['categories']
735
+ @items[item_key] = ContentNode.new(word_hash, *categories)
736
+ end
737
+
738
+ # Restore settings
739
+ @auto_rebuild = data['auto_rebuild']
740
+ @version += 1
741
+ @built_at_version = -1
742
+ @word_list = WordList.new
743
+ @dirty = false
744
+ end
745
+
746
+ # Rebuild the index
747
+ build_index
748
+ end
749
+
750
+ # @rbs (Float) -> void
751
+ def validate_cutoff!(cutoff)
752
+ return if cutoff.positive? && cutoff < 1
753
+
754
+ raise ArgumentError, "cutoff must be between 0 and 1 (exclusive), got #{cutoff}"
755
+ end
756
+
757
+ # Assigns LSI vectors using native C extension
758
+ # @rbs (untyped, Array[ContentNode]) -> void
759
+ def assign_native_ext_lsi_vectors(ntdm, doc_list)
760
+ ntdm.size[1].times do |col|
761
+ vec = self.class.vector_class.alloc(ntdm.column(col).to_a).row
762
+ doc_list[col].lsi_vector = vec
763
+ doc_list[col].lsi_norm = vec.normalize
764
+ end
765
+ end
766
+
767
+ # Assigns LSI vectors using pure Ruby Matrix
768
+ # @rbs (untyped, Array[ContentNode]) -> void
769
+ def assign_ruby_lsi_vectors(ntdm, doc_list)
770
+ ntdm.column_size.times do |col|
771
+ next unless doc_list[col]
772
+
773
+ column = ntdm.column(col)
774
+ next unless column
775
+
776
+ doc_list[col].lsi_vector = column
777
+ doc_list[col].lsi_norm = column.normalize
778
+ end
779
+ end
780
+
781
+ # Unlocked version of needs_rebuild? for internal use when lock is already held
782
+ # @rbs () -> bool
783
+ def needs_rebuild_unlocked?
784
+ (@items.keys.size > 1) && (@version != @built_at_version)
785
+ end
786
+
787
+ # Unlocked version of proximity_array_for_content for internal use
788
+ # @rbs (String) ?{ (String) -> String } -> Array[[String, Float]]
789
+ def proximity_array_for_content_unlocked(doc, &)
790
+ return [] if needs_rebuild_unlocked?
791
+ return @items.keys.map { |item| [item, 1.0] } if @items.size == 1
792
+
793
+ content_node = node_for_content_unlocked(doc, &)
794
+ result =
795
+ @items.keys.collect do |item|
796
+ val = if self.class.native_available?
797
+ content_node.search_vector * @items[item].search_vector.col
798
+ else
799
+ (Matrix[content_node.search_vector] * @items[item].search_vector)[0]
800
+ end
801
+ [item, val]
802
+ end
803
+ result.sort_by { |x| x[1] }.reverse
804
+ end
805
+
806
+ # Unlocked version of proximity_norms_for_content for internal use
807
+ # @rbs (String) ?{ (String) -> String } -> Array[[String, Float]]
808
+ def proximity_norms_for_content_unlocked(doc, &)
809
+ return [] if needs_rebuild_unlocked?
810
+
811
+ content_node = node_for_content_unlocked(doc, &)
812
+ result =
813
+ @items.keys.collect do |item|
814
+ val = if self.class.native_available?
815
+ content_node.search_norm * @items[item].search_norm.col
816
+ else
817
+ (Matrix[content_node.search_norm] * @items[item].search_norm)[0]
818
+ end
819
+ [item, val]
820
+ end
821
+ result.sort_by { |x| x[1] }.reverse
822
+ end
823
+
824
+ # Unlocked version of vote for internal use
825
+ # @rbs (String, ?Float) ?{ (String) -> String } -> Hash[String | Symbol, Float]
826
+ def vote_unlocked(doc, cutoff = 0.30, &)
827
+ icutoff = (@items.size * cutoff).round
828
+ carry = proximity_array_for_content_unlocked(doc, &)
829
+ carry = carry[0..(icutoff - 1)]
830
+ votes = {}
831
+ carry.each do |pair|
832
+ categories = @items[pair[0]].categories
833
+ categories.each do |category|
834
+ votes[category] ||= 0.0
835
+ votes[category] += pair[1]
836
+ end
837
+ end
838
+ votes
839
+ end
840
+
841
+ # Unlocked version of node_for_content for internal use.
379
842
  # @rbs (String) ?{ (String) -> String } -> ContentNode
380
- def node_for_content(item, &block)
843
+ def node_for_content_unlocked(item, &block)
381
844
  return @items[item] if @items[item]
382
845
 
383
846
  clean_word_hash = block ? block.call(item).clean_word_hash : item.to_s.clean_word_hash
384
847
  cn = ContentNode.new(clean_word_hash, &block)
385
- cn.raw_vector_with(@word_list) unless needs_rebuild?
848
+ cn.raw_vector_with(@word_list) unless needs_rebuild_unlocked?
849
+ assign_lsi_vector_incremental(cn) if incremental_enabled?
386
850
  cn
387
851
  end
388
852
 
853
+ # @rbs (untyped, ?Float) -> untyped
854
+ def build_reduced_matrix(matrix, cutoff = 0.75)
855
+ result, _u = build_reduced_matrix_with_u(matrix, cutoff)
856
+ result
857
+ end
858
+
859
+ # Builds reduced matrix and returns both the result and the U matrix.
860
+ # U matrix is needed for incremental SVD updates.
861
+ # @rbs (untyped, ?Float) -> [untyped, Matrix]
862
+ def build_reduced_matrix_with_u(matrix, cutoff = 0.75)
863
+ u, v, s = matrix.SV_decomp
864
+
865
+ all_singular_values = s.sort.reverse
866
+ s_cutoff_index = [(s.size * cutoff).round - 1, 0].max
867
+ s_cutoff = all_singular_values[s_cutoff_index]
868
+
869
+ kept_indices = []
870
+ kept_singular_values = []
871
+ s.size.times do |ord|
872
+ if s[ord] >= s_cutoff
873
+ kept_indices << ord
874
+ kept_singular_values << s[ord]
875
+ else
876
+ s[ord] = 0.0
877
+ end
878
+ end
879
+
880
+ @singular_values = kept_singular_values.sort.reverse
881
+ result = u * self.class.matrix_class.diag(s) * v.trans
882
+ result = result.trans if result.row_size != matrix.row_size
883
+ u_reduced = extract_reduced_u(u, kept_indices, s)
884
+
885
+ [result, u_reduced]
886
+ end
887
+
888
+ # Extracts columns from U corresponding to kept singular values.
889
+ # Columns are sorted by descending singular value to match @singular_values order.
890
+ # rubocop:disable Naming/MethodParameterName
891
+ # @rbs (untyped, Array[Integer], Array[Float]) -> Matrix
892
+ def extract_reduced_u(u, kept_indices, singular_values)
893
+ return Matrix.empty(u.row_size, 0) if kept_indices.empty?
894
+
895
+ sorted_indices = kept_indices.sort_by { |i| -singular_values[i] }
896
+
897
+ if u.respond_to?(:to_ruby_matrix)
898
+ u = u.to_ruby_matrix
899
+ elsif !u.is_a?(::Matrix)
900
+ rows = u.row_size.times.map do |i|
901
+ sorted_indices.map { |j| u[i, j] }
902
+ end
903
+ return Matrix.rows(rows)
904
+ end
905
+
906
+ cols = sorted_indices.map { |i| u.column(i).to_a }
907
+ Matrix.columns(cols)
908
+ end
909
+ # rubocop:enable Naming/MethodParameterName
910
+
389
911
  # @rbs () -> void
390
912
  def make_word_list
391
913
  @word_list = WordList.new
@@ -393,5 +915,129 @@ module Classifier
393
915
  node.word_hash.each_key { |key| @word_list.add_word key }
394
916
  end
395
917
  end
918
+
919
+ # Performs incremental SVD update for a new document.
920
+ # @rbs (ContentNode, Hash[Symbol, Integer]) -> void
921
+ def perform_incremental_update(node, word_hash)
922
+ needs_full_rebuild = false
923
+ old_rank = nil
924
+
925
+ synchronize do
926
+ if vocabulary_growth_exceeds_threshold?(word_hash)
927
+ disable_incremental_mode!
928
+ needs_full_rebuild = true
929
+ next
930
+ end
931
+
932
+ old_rank = @u_matrix.column_size
933
+ extend_vocabulary_for_incremental(word_hash)
934
+ raw_vec = node.raw_vector_with(@word_list)
935
+ raw_vector = Vector[*raw_vec.to_a]
936
+
937
+ @u_matrix, @singular_values = IncrementalSVD.update(
938
+ @u_matrix, @singular_values, raw_vector, max_rank: @max_rank
939
+ )
940
+
941
+ new_rank = @u_matrix.column_size
942
+ if new_rank > old_rank
943
+ reproject_all_documents
944
+ else
945
+ assign_lsi_vector_incremental(node)
946
+ end
947
+
948
+ @built_at_version = @version
949
+ end
950
+
951
+ build_index if needs_full_rebuild
952
+ end
953
+
954
+ # Checks if vocabulary growth would exceed threshold (20%)
955
+ # @rbs (Hash[Symbol, Integer]) -> bool
956
+ def vocabulary_growth_exceeds_threshold?(word_hash)
957
+ return false unless @initial_vocab_size&.positive?
958
+
959
+ new_words = word_hash.keys.count { |w| @word_list[w].nil? }
960
+ growth_ratio = new_words.to_f / @initial_vocab_size
961
+ growth_ratio > 0.2
962
+ end
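For example, with an initial vocabulary of 500 stems recorded at build time, a new document would need to introduce more than 100 previously unseen stems (a growth ratio above 0.2) before incremental mode is abandoned in favor of a full rebuild.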
963
+
964
+ # Extends vocabulary and U matrix for new words.
965
+ # @rbs (Hash[Symbol, Integer]) -> void
966
+ def extend_vocabulary_for_incremental(word_hash)
967
+ new_words = word_hash.keys.select { |w| @word_list[w].nil? }
968
+ return if new_words.empty?
969
+
970
+ new_words.each { |word| @word_list.add_word(word) }
971
+ extend_u_matrix(new_words.size)
972
+ end
973
+
974
+ # Extends U matrix with zero rows for new vocabulary terms.
975
+ # @rbs (Integer) -> void
976
+ def extend_u_matrix(num_new_rows)
977
+ return if num_new_rows.zero? || @u_matrix.nil?
978
+
979
+ if self.class.native_available? && @u_matrix.is_a?(self.class.matrix_class)
980
+ new_rows = self.class.matrix_class.zeros(num_new_rows, @u_matrix.column_size)
981
+ @u_matrix = self.class.matrix_class.vstack(@u_matrix, new_rows)
982
+ else
983
+ new_rows = Matrix.zero(num_new_rows, @u_matrix.column_size)
984
+ @u_matrix = Matrix.vstack(@u_matrix, new_rows)
985
+ end
986
+ end
987
+
988
+ # Re-projects all documents onto the current U matrix
989
+ # Called when rank grows to ensure consistent LSI vector sizes
990
+ # Uses native batch_project for performance when available
991
+ # @rbs () -> void
992
+ def reproject_all_documents
993
+ return unless @u_matrix
994
+ return reproject_all_documents_native if self.class.native_available? && @u_matrix.respond_to?(:batch_project)
995
+
996
+ reproject_all_documents_ruby
997
+ end
998
+
999
+ # Native batch re-projection using C extension.
1000
+ # @rbs () -> void
1001
+ def reproject_all_documents_native
1002
+ nodes = @items.values
1003
+ raw_vectors = nodes.map do |node|
1004
+ raw = node.raw_vector_with(@word_list)
1005
+ raw.is_a?(self.class.vector_class) ? raw : self.class.vector_class.alloc(raw.to_a)
1006
+ end
1007
+
1008
+ lsi_vectors = @u_matrix.batch_project(raw_vectors)
1009
+
1010
+ nodes.each_with_index do |node, i|
1011
+ lsi_vec = lsi_vectors[i].row
1012
+ node.lsi_vector = lsi_vec
1013
+ node.lsi_norm = lsi_vec.normalize
1014
+ end
1015
+ end
1016
+
1017
+ # Pure Ruby re-projection (fallback)
1018
+ # @rbs () -> void
1019
+ def reproject_all_documents_ruby
1020
+ @items.each_value do |node|
1021
+ assign_lsi_vector_incremental(node)
1022
+ end
1023
+ end
1024
+
1025
+ # Assigns LSI vector to a node using projection: lsi_vec = U^T * raw_vec.
1026
+ # @rbs (ContentNode) -> void
1027
+ def assign_lsi_vector_incremental(node)
1028
+ return unless @u_matrix
1029
+
1030
+ raw_vec = node.raw_vector_with(@word_list)
1031
+ raw_vector = Vector[*raw_vec.to_a]
1032
+ lsi_arr = (@u_matrix.transpose * raw_vector).to_a
1033
+
1034
+ lsi_vec = if self.class.native_available?
1035
+ self.class.vector_class.alloc(lsi_arr).row
1036
+ else
1037
+ Vector[*lsi_arr]
1038
+ end
1039
+ node.lsi_vector = lsi_vec
1040
+ node.lsi_norm = lsi_vec.normalize
1041
+ end
396
1042
  end
397
1043
  end