classifier 2.0.0 → 2.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CLAUDE.md +23 -13
- data/README.md +72 -190
- data/ext/classifier/classifier_ext.c +26 -0
- data/ext/classifier/extconf.rb +15 -0
- data/ext/classifier/incremental_svd.c +393 -0
- data/ext/classifier/linalg.h +72 -0
- data/ext/classifier/matrix.c +387 -0
- data/ext/classifier/svd.c +208 -0
- data/ext/classifier/vector.c +319 -0
- data/lib/classifier/bayes.rb +398 -54
- data/lib/classifier/errors.rb +19 -0
- data/lib/classifier/extensions/vector.rb +12 -4
- data/lib/classifier/knn.rb +351 -0
- data/lib/classifier/logistic_regression.rb +571 -0
- data/lib/classifier/lsi/content_node.rb +5 -5
- data/lib/classifier/lsi/incremental_svd.rb +166 -0
- data/lib/classifier/lsi/summary.rb +25 -5
- data/lib/classifier/lsi.rb +784 -138
- data/lib/classifier/storage/base.rb +50 -0
- data/lib/classifier/storage/file.rb +51 -0
- data/lib/classifier/storage/memory.rb +49 -0
- data/lib/classifier/storage.rb +9 -0
- data/lib/classifier/streaming/line_reader.rb +99 -0
- data/lib/classifier/streaming/progress.rb +96 -0
- data/lib/classifier/streaming.rb +122 -0
- data/lib/classifier/tfidf.rb +408 -0
- data/lib/classifier.rb +6 -0
- data/sig/vendor/json.rbs +4 -0
- data/sig/vendor/matrix.rbs +25 -14
- data/sig/vendor/mutex_m.rbs +16 -0
- data/sig/vendor/streaming.rbs +14 -0
- data/test/test_helper.rb +2 -0
- metadata +52 -8
- data/lib/classifier/extensions/vector_serialize.rb +0 -18
data/lib/classifier/lsi.rb
CHANGED
|
@@ -6,61 +6,110 @@
|
|
|
6
6
|
|
|
7
7
|
module Classifier
|
|
8
8
|
class LSI
|
|
9
|
-
#
|
|
10
|
-
@
|
|
9
|
+
# Backend options: :native, :ruby
|
|
10
|
+
# @rbs @backend: Symbol
|
|
11
|
+
@backend = :ruby
|
|
11
12
|
|
|
12
13
|
class << self
|
|
13
|
-
# @rbs @
|
|
14
|
-
attr_accessor :
|
|
14
|
+
# @rbs @backend: Symbol
|
|
15
|
+
attr_accessor :backend
|
|
16
|
+
|
|
17
|
+
# Check if using native C extension
|
|
18
|
+
# @rbs () -> bool
|
|
19
|
+
def native_available?
|
|
20
|
+
backend == :native
|
|
21
|
+
end
|
|
22
|
+
|
|
23
|
+
# Get the Vector class for the current backend
|
|
24
|
+
# @rbs () -> Class
|
|
25
|
+
def vector_class
|
|
26
|
+
backend == :native ? Classifier::Linalg::Vector : ::Vector
|
|
27
|
+
end
|
|
28
|
+
|
|
29
|
+
# Get the Matrix class for the current backend
|
|
30
|
+
# @rbs () -> Class
|
|
31
|
+
def matrix_class
|
|
32
|
+
backend == :native ? Classifier::Linalg::Matrix : ::Matrix
|
|
33
|
+
end
|
|
15
34
|
end
|
|
16
35
|
end
|
|
17
36
|
end
|
|
18
37
|
|
|
38
|
+
# Backend detection: native extension > pure Ruby
|
|
39
|
+
# Set NATIVE_VECTOR=true to force pure Ruby implementation
|
|
40
|
+
|
|
19
41
|
begin
|
|
20
|
-
# to test the native vector class, try `rake test NATIVE_VECTOR=true`
|
|
21
42
|
raise LoadError if ENV['NATIVE_VECTOR'] == 'true'
|
|
22
|
-
raise LoadError unless Gem::Specification.find_all_by_name('gsl').any?
|
|
23
43
|
|
|
24
|
-
require '
|
|
25
|
-
|
|
26
|
-
Classifier::LSI.gsl_available = true
|
|
44
|
+
require 'classifier/classifier_ext'
|
|
45
|
+
Classifier::LSI.backend = :native
|
|
27
46
|
rescue LoadError
|
|
28
|
-
|
|
29
|
-
|
|
47
|
+
# Fall back to pure Ruby implementation
|
|
48
|
+
unless ENV['SUPPRESS_LSI_WARNING'] == 'true'
|
|
49
|
+
warn 'Notice: for 5-10x faster LSI, install the classifier gem with native extensions. ' \
|
|
50
|
+
'Set SUPPRESS_LSI_WARNING=true to hide this.'
|
|
30
51
|
end
|
|
31
|
-
Classifier::LSI.
|
|
52
|
+
Classifier::LSI.backend = :ruby
|
|
32
53
|
require 'classifier/extensions/vector'
|
|
33
54
|
end
|
|
34
55
|
|
|
56
|
+
require 'json'
|
|
57
|
+
require 'mutex_m'
|
|
35
58
|
require 'classifier/lsi/word_list'
|
|
36
59
|
require 'classifier/lsi/content_node'
|
|
37
60
|
require 'classifier/lsi/summary'
|
|
61
|
+
require 'classifier/lsi/incremental_svd'
|
|
38
62
|
|
|
39
63
|
module Classifier
|
|
40
64
|
# This class implements a Latent Semantic Indexer, which can search, classify and cluster
|
|
41
65
|
# data based on underlying semantic relations. For more information on the algorithms used,
|
|
42
66
|
# please consult Wikipedia[http://en.wikipedia.org/wiki/Latent_Semantic_Indexing].
|
|
43
67
|
class LSI
|
|
68
|
+
include Mutex_m
|
|
69
|
+
include Streaming
|
|
70
|
+
|
|
44
71
|
# @rbs @auto_rebuild: bool
|
|
45
72
|
# @rbs @word_list: WordList
|
|
46
73
|
# @rbs @items: Hash[untyped, ContentNode]
|
|
47
74
|
# @rbs @version: Integer
|
|
48
75
|
# @rbs @built_at_version: Integer
|
|
76
|
+
# @rbs @singular_values: Array[Float]?
|
|
77
|
+
# @rbs @dirty: bool
|
|
78
|
+
# @rbs @storage: Storage::Base?
|
|
79
|
+
# @rbs @incremental_mode: bool
|
|
80
|
+
# @rbs @u_matrix: Matrix?
|
|
81
|
+
# @rbs @max_rank: Integer
|
|
82
|
+
# @rbs @initial_vocab_size: Integer?
|
|
49
83
|
|
|
50
|
-
attr_reader :word_list
|
|
51
|
-
attr_accessor :auto_rebuild
|
|
84
|
+
attr_reader :word_list, :singular_values
|
|
85
|
+
attr_accessor :auto_rebuild, :storage
|
|
86
|
+
|
|
87
|
+
# Default maximum rank for incremental SVD
|
|
88
|
+
DEFAULT_MAX_RANK = 100
|
|
52
89
|
|
|
53
90
|
# Create a fresh index.
|
|
54
91
|
# If you want to call #build_index manually, use
|
|
55
92
|
# Classifier::LSI.new auto_rebuild: false
|
|
56
93
|
#
|
|
94
|
+
# For incremental SVD mode (adds documents without full rebuild):
|
|
95
|
+
# Classifier::LSI.new incremental: true, max_rank: 100
|
|
96
|
+
#
|
|
57
97
|
# @rbs (?Hash[Symbol, untyped]) -> void
|
|
58
98
|
def initialize(options = {})
|
|
99
|
+
super()
|
|
59
100
|
@auto_rebuild = true unless options[:auto_rebuild] == false
|
|
60
101
|
@word_list = WordList.new
|
|
61
102
|
@items = {}
|
|
62
103
|
@version = 0
|
|
63
104
|
@built_at_version = -1
|
|
105
|
+
@dirty = false
|
|
106
|
+
@storage = nil
|
|
107
|
+
|
|
108
|
+
# Incremental SVD settings
|
|
109
|
+
@incremental_mode = options[:incremental] == true
|
|
110
|
+
@max_rank = options[:max_rank] || DEFAULT_MAX_RANK
|
|
111
|
+
@u_matrix = nil
|
|
112
|
+
@initial_vocab_size = nil
|
|
64
113
|
end
|
|
65
114
|
|
|
66
115
|
# Returns true if the index needs to be rebuilt. The index needs
|
|
@@ -69,7 +118,85 @@ module Classifier
|
|
|
69
118
|
#
|
|
70
119
|
# @rbs () -> bool
|
|
71
120
|
def needs_rebuild?
|
|
72
|
-
(@items.keys.size > 1) && (@version != @built_at_version)
|
|
121
|
+
synchronize { (@items.keys.size > 1) && (@version != @built_at_version) }
|
|
122
|
+
end
|
|
123
|
+
|
|
124
|
+
# @rbs () -> Array[Hash[Symbol, untyped]]?
|
|
125
|
+
def singular_value_spectrum
|
|
126
|
+
return nil unless @singular_values
|
|
127
|
+
|
|
128
|
+
total = @singular_values.sum
|
|
129
|
+
return nil if total.zero?
|
|
130
|
+
|
|
131
|
+
cumulative = 0.0
|
|
132
|
+
@singular_values.map.with_index do |value, i|
|
|
133
|
+
cumulative += value
|
|
134
|
+
{
|
|
135
|
+
dimension: i,
|
|
136
|
+
value: value,
|
|
137
|
+
percentage: value / total,
|
|
138
|
+
cumulative_percentage: cumulative / total
|
|
139
|
+
}
|
|
140
|
+
end
|
|
141
|
+
end
|
|
142
|
+
|
|
143
|
+
# Returns true if incremental mode is enabled and active.
|
|
144
|
+
# Incremental mode becomes active after the first build_index call.
|
|
145
|
+
#
|
|
146
|
+
# @rbs () -> bool
|
|
147
|
+
def incremental_enabled?
|
|
148
|
+
@incremental_mode && !@u_matrix.nil?
|
|
149
|
+
end
|
|
150
|
+
|
|
151
|
+
# Returns the current rank of the incremental SVD (number of singular values kept).
|
|
152
|
+
# Returns nil if incremental mode is not active.
|
|
153
|
+
#
|
|
154
|
+
# @rbs () -> Integer?
|
|
155
|
+
def current_rank
|
|
156
|
+
@singular_values&.count(&:positive?)
|
|
157
|
+
end
|
|
158
|
+
|
|
159
|
+
# Disables incremental mode. Subsequent adds will trigger full rebuilds.
|
|
160
|
+
#
|
|
161
|
+
# @rbs () -> void
|
|
162
|
+
def disable_incremental_mode!
|
|
163
|
+
@incremental_mode = false
|
|
164
|
+
@u_matrix = nil
|
|
165
|
+
@initial_vocab_size = nil
|
|
166
|
+
end
|
|
167
|
+
|
|
168
|
+
# Enables incremental mode with optional max_rank setting.
|
|
169
|
+
# The next build_index call will store the U matrix for incremental updates.
|
|
170
|
+
#
|
|
171
|
+
# @rbs (?max_rank: Integer) -> void
|
|
172
|
+
def enable_incremental_mode!(max_rank: DEFAULT_MAX_RANK)
|
|
173
|
+
@incremental_mode = true
|
|
174
|
+
@max_rank = max_rank
|
|
175
|
+
end
|
|
176
|
+
|
|
177
|
+
# Adds items to the index using hash-style syntax.
|
|
178
|
+
# The hash keys are categories, and values are items (or arrays of items).
|
|
179
|
+
#
|
|
180
|
+
# For example:
|
|
181
|
+
# lsi = Classifier::LSI.new
|
|
182
|
+
# lsi.add("Dog" => "Dogs are loyal pets")
|
|
183
|
+
# lsi.add("Cat" => "Cats are independent")
|
|
184
|
+
# lsi.add(Bird: "Birds can fly") # Symbol keys work too
|
|
185
|
+
#
|
|
186
|
+
# Multiple items with the same category:
|
|
187
|
+
# lsi.add("Dog" => ["Dogs are loyal", "Puppies are cute"])
|
|
188
|
+
#
|
|
189
|
+
# Batch operations with multiple categories:
|
|
190
|
+
# lsi.add(
|
|
191
|
+
# "Dog" => ["Dogs are loyal", "Puppies are cute"],
|
|
192
|
+
# "Cat" => ["Cats are independent", "Kittens are playful"]
|
|
193
|
+
# )
|
|
194
|
+
#
|
|
195
|
+
# @rbs (**untyped items) -> void
|
|
196
|
+
def add(**items)
|
|
197
|
+
items.each do |category, value|
|
|
198
|
+
Array(value).each { |doc| add_item(doc, category.to_s) }
|
|
199
|
+
end
|
|
73
200
|
end
|
|
74
201
|
|
|
75
202
|
# Adds an item to the index. item is assumed to be a string, but
|
|
@@ -78,6 +205,8 @@ module Classifier
|
|
|
78
205
|
# fetch fresh string data. This optional block is passed the item,
|
|
79
206
|
# so the item may only be a reference to a URL or file name.
|
|
80
207
|
#
|
|
208
|
+
# @deprecated Use {#add} instead for clearer hash-style syntax.
|
|
209
|
+
#
|
|
81
210
|
# For example:
|
|
82
211
|
# lsi = Classifier::LSI.new
|
|
83
212
|
# lsi.add_item "This is just plain text"
|
|
@@ -88,8 +217,18 @@ module Classifier
|
|
|
88
217
|
# @rbs (String, *String | Symbol) ?{ (String) -> String } -> void
|
|
89
218
|
def add_item(item, *categories, &block)
|
|
90
219
|
clean_word_hash = block ? block.call(item).clean_word_hash : item.to_s.clean_word_hash
|
|
91
|
-
|
|
92
|
-
|
|
220
|
+
node = nil
|
|
221
|
+
|
|
222
|
+
synchronize do
|
|
223
|
+
node = ContentNode.new(clean_word_hash, *categories)
|
|
224
|
+
@items[item] = node
|
|
225
|
+
@version += 1
|
|
226
|
+
@dirty = true
|
|
227
|
+
end
|
|
228
|
+
|
|
229
|
+
# Use incremental update if enabled and we have a U matrix
|
|
230
|
+
return perform_incremental_update(node, clean_word_hash) if @incremental_mode && @u_matrix
|
|
231
|
+
|
|
93
232
|
build_index if @auto_rebuild
|
|
94
233
|
end
|
|
95
234
|
|
|
@@ -107,25 +246,32 @@ module Classifier
|
|
|
107
246
|
#
|
|
108
247
|
# @rbs (String) -> Array[String | Symbol]
|
|
109
248
|
def categories_for(item)
|
|
110
|
-
|
|
249
|
+
synchronize do
|
|
250
|
+
return [] unless @items[item]
|
|
111
251
|
|
|
112
|
-
|
|
252
|
+
@items[item].categories
|
|
253
|
+
end
|
|
113
254
|
end
|
|
114
255
|
|
|
115
256
|
# Removes an item from the database, if it is indexed.
|
|
116
257
|
#
|
|
117
258
|
# @rbs (String) -> void
|
|
118
259
|
def remove_item(item)
|
|
119
|
-
|
|
260
|
+
removed = synchronize do
|
|
261
|
+
next false unless @items.key?(item)
|
|
120
262
|
|
|
121
|
-
|
|
122
|
-
|
|
263
|
+
@items.delete(item)
|
|
264
|
+
@version += 1
|
|
265
|
+
@dirty = true
|
|
266
|
+
true
|
|
267
|
+
end
|
|
268
|
+
build_index if removed && @auto_rebuild
|
|
123
269
|
end
|
|
124
270
|
|
|
125
271
|
# Returns an array of items that are indexed.
|
|
126
272
|
# @rbs () -> Array[untyped]
|
|
127
273
|
def items
|
|
128
|
-
@items.keys
|
|
274
|
+
synchronize { @items.keys }
|
|
129
275
|
end
|
|
130
276
|
|
|
131
277
|
# This function rebuilds the index if needs_rebuild? returns true.
|
|
@@ -143,40 +289,38 @@ module Classifier
|
|
|
143
289
|
# A value of 1 for cutoff means that no semantic analysis will take place,
|
|
144
290
|
# turning the LSI class into a simple vector search engine.
|
|
145
291
|
#
|
|
146
|
-
# @rbs (?Float) -> void
|
|
147
|
-
def build_index(cutoff = 0.75)
|
|
148
|
-
|
|
292
|
+
# @rbs (?Float, ?force: bool) -> void
|
|
293
|
+
def build_index(cutoff = 0.75, force: false)
|
|
294
|
+
validate_cutoff!(cutoff)
|
|
149
295
|
|
|
150
|
-
|
|
296
|
+
synchronize do
|
|
297
|
+
return unless force || needs_rebuild_unlocked?
|
|
151
298
|
|
|
152
|
-
|
|
153
|
-
tda = doc_list.collect { |node| node.raw_vector_with(@word_list) }
|
|
299
|
+
make_word_list
|
|
154
300
|
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
ntdm = build_reduced_matrix(tdm, cutoff)
|
|
301
|
+
doc_list = @items.values
|
|
302
|
+
tda = doc_list.collect { |node| node.raw_vector_with(@word_list) }
|
|
158
303
|
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
304
|
+
if self.class.native_available?
|
|
305
|
+
# Convert vectors to arrays for matrix construction
|
|
306
|
+
tda_arrays = tda.map { |v| v.respond_to?(:to_a) ? v.to_a : v }
|
|
307
|
+
tdm = self.class.matrix_class.alloc(*tda_arrays).trans
|
|
308
|
+
ntdm, u_mat = build_reduced_matrix_with_u(tdm, cutoff)
|
|
309
|
+
assign_native_ext_lsi_vectors(ntdm, doc_list)
|
|
310
|
+
else
|
|
311
|
+
tdm = Matrix.rows(tda).trans
|
|
312
|
+
ntdm, u_mat = build_reduced_matrix_with_u(tdm, cutoff)
|
|
313
|
+
assign_ruby_lsi_vectors(ntdm, doc_list)
|
|
163
314
|
end
|
|
164
|
-
else
|
|
165
|
-
tdm = Matrix.rows(tda).trans
|
|
166
|
-
ntdm = build_reduced_matrix(tdm, cutoff)
|
|
167
315
|
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
next unless column
|
|
173
|
-
|
|
174
|
-
doc_list[col].lsi_vector = column
|
|
175
|
-
doc_list[col].lsi_norm = column.normalize
|
|
316
|
+
# Store U matrix for incremental mode
|
|
317
|
+
if @incremental_mode
|
|
318
|
+
@u_matrix = u_mat
|
|
319
|
+
@initial_vocab_size = @word_list.size
|
|
176
320
|
end
|
|
177
|
-
end
|
|
178
321
|
|
|
179
|
-
|
|
322
|
+
@built_at_version = @version
|
|
323
|
+
end
|
|
180
324
|
end
|
|
181
325
|
|
|
182
326
|
# This method returns max_chunks entries, ordered by their average semantic rating.
|
|
@@ -190,12 +334,14 @@ module Classifier
|
|
|
190
334
|
#
|
|
191
335
|
# @rbs (?Integer) -> Array[String]
|
|
192
336
|
def highest_relative_content(max_chunks = 10)
|
|
193
|
-
|
|
337
|
+
synchronize do
|
|
338
|
+
return [] if needs_rebuild_unlocked?
|
|
194
339
|
|
|
195
|
-
|
|
196
|
-
|
|
340
|
+
avg_density = {}
|
|
341
|
+
@items.each_key { |x| avg_density[x] = proximity_array_for_content_unlocked(x).sum { |pair| pair[1] } }
|
|
197
342
|
|
|
198
|
-
|
|
343
|
+
avg_density.keys.sort_by { |x| avg_density[x] }.reverse[0..(max_chunks - 1)].map
|
|
344
|
+
end
|
|
199
345
|
end
|
|
200
346
|
|
|
201
347
|
# This function is the primitive that find_related and classify
|
|
@@ -212,20 +358,8 @@ module Classifier
|
|
|
212
358
|
# text data. See add_item for examples of how this works.
|
|
213
359
|
#
|
|
214
360
|
# @rbs (String) ?{ (String) -> String } -> Array[[String, Float]]
|
|
215
|
-
def proximity_array_for_content(doc, &)
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
content_node = node_for_content(doc, &)
|
|
219
|
-
result =
|
|
220
|
-
@items.keys.collect do |item|
|
|
221
|
-
val = if self.class.gsl_available
|
|
222
|
-
content_node.search_vector * @items[item].search_vector.col
|
|
223
|
-
else
|
|
224
|
-
(Matrix[content_node.search_vector] * @items[item].search_vector)[0]
|
|
225
|
-
end
|
|
226
|
-
[item, val]
|
|
227
|
-
end
|
|
228
|
-
result.sort_by { |x| x[1] }.reverse
|
|
361
|
+
def proximity_array_for_content(doc, &block)
|
|
362
|
+
synchronize { proximity_array_for_content_unlocked(doc, &block) }
|
|
229
363
|
end
|
|
230
364
|
|
|
231
365
|
# Similar to proximity_array_for_content, this function takes similar
|
|
@@ -235,20 +369,8 @@ module Classifier
|
|
|
235
369
|
# the text you're working with. search uses this primitive.
|
|
236
370
|
#
|
|
237
371
|
# @rbs (String) ?{ (String) -> String } -> Array[[String, Float]]
|
|
238
|
-
def proximity_norms_for_content(doc, &)
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
content_node = node_for_content(doc, &)
|
|
242
|
-
result =
|
|
243
|
-
@items.keys.collect do |item|
|
|
244
|
-
val = if self.class.gsl_available
|
|
245
|
-
content_node.search_norm * @items[item].search_norm.col
|
|
246
|
-
else
|
|
247
|
-
(Matrix[content_node.search_norm] * @items[item].search_norm)[0]
|
|
248
|
-
end
|
|
249
|
-
[item, val]
|
|
250
|
-
end
|
|
251
|
-
result.sort_by { |x| x[1] }.reverse
|
|
372
|
+
def proximity_norms_for_content(doc, &block)
|
|
373
|
+
synchronize { proximity_norms_for_content_unlocked(doc, &block) }
|
|
252
374
|
end
|
|
253
375
|
|
|
254
376
|
# This function allows for text-based search of your index. Unlike other functions
|
|
@@ -261,11 +383,13 @@ module Classifier
|
|
|
261
383
|
#
|
|
262
384
|
# @rbs (String, ?Integer) -> Array[String]
|
|
263
385
|
def search(string, max_nearest = 3)
|
|
264
|
-
|
|
386
|
+
synchronize do
|
|
387
|
+
return [] if needs_rebuild_unlocked?
|
|
265
388
|
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
389
|
+
carry = proximity_norms_for_content_unlocked(string)
|
|
390
|
+
result = carry.collect { |x| x[0] }
|
|
391
|
+
result[0..(max_nearest - 1)]
|
|
392
|
+
end
|
|
269
393
|
end
|
|
270
394
|
|
|
271
395
|
# This function takes content and finds other documents
|
|
@@ -280,10 +404,12 @@ module Classifier
|
|
|
280
404
|
#
|
|
281
405
|
# @rbs (String, ?Integer) ?{ (String) -> String } -> Array[String]
|
|
282
406
|
def find_related(doc, max_nearest = 3, &block)
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
407
|
+
synchronize do
|
|
408
|
+
carry =
|
|
409
|
+
proximity_array_for_content_unlocked(doc, &block).reject { |pair| pair[0] == doc }
|
|
410
|
+
result = carry.collect { |x| x[0] }
|
|
411
|
+
result[0..(max_nearest - 1)]
|
|
412
|
+
end
|
|
287
413
|
end
|
|
288
414
|
|
|
289
415
|
# This function uses a voting system to categorize documents, based on
|
|
@@ -291,32 +417,23 @@ module Classifier
|
|
|
291
417
|
# find_related function to find related documents, then returns the
|
|
292
418
|
# most obvious category from this list.
|
|
293
419
|
#
|
|
294
|
-
# cutoff signifies the number of documents to consider when clasifying
|
|
295
|
-
# text. A cutoff of 1 means that every document in the index votes on
|
|
296
|
-
# what category the document is in. This may not always make sense.
|
|
297
|
-
#
|
|
298
420
|
# @rbs (String, ?Float) ?{ (String) -> String } -> String | Symbol
|
|
299
|
-
def classify(doc, cutoff = 0.30, &)
|
|
300
|
-
|
|
421
|
+
def classify(doc, cutoff = 0.30, &block)
|
|
422
|
+
validate_cutoff!(cutoff)
|
|
423
|
+
|
|
424
|
+
synchronize do
|
|
425
|
+
votes = vote_unlocked(doc, cutoff, &block)
|
|
301
426
|
|
|
302
|
-
|
|
303
|
-
|
|
427
|
+
ranking = votes.keys.sort_by { |x| votes[x] }
|
|
428
|
+
ranking[-1]
|
|
429
|
+
end
|
|
304
430
|
end
|
|
305
431
|
|
|
306
432
|
# @rbs (String, ?Float) ?{ (String) -> String } -> Hash[String | Symbol, Float]
|
|
307
|
-
def vote(doc, cutoff = 0.30, &)
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
votes = {}
|
|
312
|
-
carry.each do |pair|
|
|
313
|
-
categories = @items[pair[0]].categories
|
|
314
|
-
categories.each do |category|
|
|
315
|
-
votes[category] ||= 0.0
|
|
316
|
-
votes[category] += pair[1]
|
|
317
|
-
end
|
|
318
|
-
end
|
|
319
|
-
votes
|
|
433
|
+
def vote(doc, cutoff = 0.30, &block)
|
|
434
|
+
validate_cutoff!(cutoff)
|
|
435
|
+
|
|
436
|
+
synchronize { vote_unlocked(doc, cutoff, &block) }
|
|
320
437
|
end
|
|
321
438
|
|
|
322
439
|
# Returns the same category as classify() but also returns
|
|
@@ -331,15 +448,19 @@ module Classifier
|
|
|
331
448
|
#
|
|
332
449
|
# See classify() for argument docs
|
|
333
450
|
# @rbs (String, ?Float) ?{ (String) -> String } -> [String | Symbol | nil, Float?]
|
|
334
|
-
def classify_with_confidence(doc, cutoff = 0.30, &)
|
|
335
|
-
|
|
336
|
-
votes_sum = votes.values.sum
|
|
337
|
-
return [nil, nil] if votes_sum.zero?
|
|
451
|
+
def classify_with_confidence(doc, cutoff = 0.30, &block)
|
|
452
|
+
validate_cutoff!(cutoff)
|
|
338
453
|
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
454
|
+
synchronize do
|
|
455
|
+
votes = vote_unlocked(doc, cutoff, &block)
|
|
456
|
+
votes_sum = votes.values.sum
|
|
457
|
+
return [nil, nil] if votes_sum.zero?
|
|
458
|
+
|
|
459
|
+
ranking = votes.keys.sort_by { |x| votes[x] }
|
|
460
|
+
winner = ranking[-1]
|
|
461
|
+
vote_share = votes[winner] / votes_sum.to_f
|
|
462
|
+
[winner, vote_share]
|
|
463
|
+
end
|
|
343
464
|
end
|
|
344
465
|
|
|
345
466
|
# Prototype, only works on indexed documents.
|
|
@@ -347,45 +468,446 @@ module Classifier
|
|
|
347
468
|
# it's supposed to.
|
|
348
469
|
# @rbs (String, ?Integer) -> Array[Symbol]
|
|
349
470
|
def highest_ranked_stems(doc, count = 3)
|
|
350
|
-
|
|
471
|
+
synchronize do
|
|
472
|
+
raise 'Requested stem ranking on non-indexed content!' unless @items[doc]
|
|
351
473
|
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
474
|
+
arr = node_for_content_unlocked(doc).lsi_vector.to_a
|
|
475
|
+
top_n = arr.sort.reverse[0..(count - 1)]
|
|
476
|
+
top_n.collect { |x| @word_list.word_for_index(arr.index(x)) }
|
|
477
|
+
end
|
|
355
478
|
end
|
|
356
479
|
|
|
357
|
-
|
|
480
|
+
# Custom marshal serialization to exclude mutex state
|
|
481
|
+
# @rbs () -> Array[untyped]
|
|
482
|
+
def marshal_dump
|
|
483
|
+
[@auto_rebuild, @word_list, @items, @version, @built_at_version, @dirty]
|
|
484
|
+
end
|
|
358
485
|
|
|
359
|
-
#
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
486
|
+
# Custom marshal deserialization to recreate mutex
|
|
487
|
+
# @rbs (Array[untyped]) -> void
|
|
488
|
+
def marshal_load(data)
|
|
489
|
+
mu_initialize
|
|
490
|
+
@auto_rebuild, @word_list, @items, @version, @built_at_version, @dirty = data
|
|
491
|
+
@storage = nil
|
|
492
|
+
end
|
|
363
493
|
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
494
|
+
# Returns a hash representation of the LSI index.
|
|
495
|
+
# Only source data (word_hash, categories) is included, not computed vectors.
|
|
496
|
+
# This can be converted to JSON or used directly.
|
|
497
|
+
#
|
|
498
|
+
# @rbs () -> untyped
|
|
499
|
+
def as_json(*)
|
|
500
|
+
items_data = @items.transform_values do |node|
|
|
501
|
+
{
|
|
502
|
+
word_hash: node.word_hash.transform_keys(&:to_s),
|
|
503
|
+
categories: node.categories.map(&:to_s)
|
|
504
|
+
}
|
|
505
|
+
end
|
|
506
|
+
|
|
507
|
+
{
|
|
508
|
+
version: 1,
|
|
509
|
+
type: 'lsi',
|
|
510
|
+
auto_rebuild: @auto_rebuild,
|
|
511
|
+
items: items_data
|
|
512
|
+
}
|
|
513
|
+
end
|
|
514
|
+
|
|
515
|
+
# Serializes the LSI index to a JSON string.
|
|
516
|
+
# Only source data (word_hash, categories) is serialized, not computed vectors.
|
|
517
|
+
# On load, the index will be rebuilt automatically.
|
|
518
|
+
#
|
|
519
|
+
# @rbs () -> String
|
|
520
|
+
def to_json(*)
|
|
521
|
+
as_json.to_json
|
|
522
|
+
end
|
|
523
|
+
|
|
524
|
+
# Loads an LSI index from a JSON string or Hash created by #to_json or #as_json.
|
|
525
|
+
# The index will be rebuilt after loading.
|
|
526
|
+
#
|
|
527
|
+
# @rbs (String | Hash[String, untyped]) -> LSI
|
|
528
|
+
def self.from_json(json)
|
|
529
|
+
data = json.is_a?(String) ? JSON.parse(json) : json
|
|
530
|
+
raise ArgumentError, "Invalid classifier type: #{data['type']}" unless data['type'] == 'lsi'
|
|
531
|
+
|
|
532
|
+
# Create instance with auto_rebuild disabled during loading
|
|
533
|
+
instance = new(auto_rebuild: false)
|
|
534
|
+
|
|
535
|
+
# Restore items (categories stay as strings, matching original storage)
|
|
536
|
+
data['items'].each do |item_key, item_data|
|
|
537
|
+
word_hash = item_data['word_hash'].transform_keys(&:to_sym)
|
|
538
|
+
categories = item_data['categories']
|
|
539
|
+
instance.instance_variable_get(:@items)[item_key] = ContentNode.new(word_hash, *categories)
|
|
540
|
+
instance.instance_variable_set(:@version, instance.instance_variable_get(:@version) + 1)
|
|
368
541
|
end
|
|
369
|
-
# Reconstruct the term document matrix, only with reduced rank
|
|
370
|
-
result = u * (self.class.gsl_available ? GSL::Matrix : ::Matrix).diag(s) * v.trans
|
|
371
542
|
|
|
372
|
-
#
|
|
373
|
-
|
|
374
|
-
|
|
543
|
+
# Restore auto_rebuild setting and rebuild index
|
|
544
|
+
instance.auto_rebuild = data['auto_rebuild']
|
|
545
|
+
instance.build_index
|
|
546
|
+
instance
|
|
547
|
+
end
|
|
548
|
+
|
|
549
|
+
# Saves the LSI index to the configured storage.
|
|
550
|
+
# Raises ArgumentError if no storage is configured.
|
|
551
|
+
#
|
|
552
|
+
# @rbs () -> void
|
|
553
|
+
def save
|
|
554
|
+
raise ArgumentError, 'No storage configured. Use save_to_file(path) or set storage=' unless storage
|
|
555
|
+
|
|
556
|
+
storage.write(to_json)
|
|
557
|
+
@dirty = false
|
|
558
|
+
end
|
|
375
559
|
|
|
560
|
+
# Saves the LSI index to a file (legacy API).
|
|
561
|
+
#
|
|
562
|
+
# @rbs (String) -> Integer
|
|
563
|
+
def save_to_file(path)
|
|
564
|
+
result = File.write(path, to_json)
|
|
565
|
+
@dirty = false
|
|
376
566
|
result
|
|
377
567
|
end
|
|
378
568
|
|
|
569
|
+
# Reloads the LSI index from the configured storage.
|
|
570
|
+
# Raises UnsavedChangesError if there are unsaved changes.
|
|
571
|
+
# Use reload! to force reload and discard changes.
|
|
572
|
+
#
|
|
573
|
+
# @rbs () -> self
|
|
574
|
+
def reload
|
|
575
|
+
raise ArgumentError, 'No storage configured' unless storage
|
|
576
|
+
raise UnsavedChangesError, 'Unsaved changes would be lost. Call save first or use reload!' if @dirty
|
|
577
|
+
|
|
578
|
+
data = storage.read
|
|
579
|
+
raise StorageError, 'No saved state found' unless data
|
|
580
|
+
|
|
581
|
+
restore_from_json(data)
|
|
582
|
+
@dirty = false
|
|
583
|
+
self
|
|
584
|
+
end
|
|
585
|
+
|
|
586
|
+
# Force reloads the LSI index from storage, discarding any unsaved changes.
|
|
587
|
+
#
|
|
588
|
+
# @rbs () -> self
|
|
589
|
+
def reload!
|
|
590
|
+
raise ArgumentError, 'No storage configured' unless storage
|
|
591
|
+
|
|
592
|
+
data = storage.read
|
|
593
|
+
raise StorageError, 'No saved state found' unless data
|
|
594
|
+
|
|
595
|
+
restore_from_json(data)
|
|
596
|
+
@dirty = false
|
|
597
|
+
self
|
|
598
|
+
end
|
|
599
|
+
|
|
600
|
+
# Returns true if there are unsaved changes.
|
|
601
|
+
#
|
|
602
|
+
# @rbs () -> bool
|
|
603
|
+
def dirty?
|
|
604
|
+
@dirty
|
|
605
|
+
end
|
|
606
|
+
|
|
607
|
+
# Loads an LSI index from the configured storage.
|
|
608
|
+
# The storage is set on the returned instance.
|
|
609
|
+
#
|
|
610
|
+
# @rbs (storage: Storage::Base) -> LSI
|
|
611
|
+
def self.load(storage:)
|
|
612
|
+
data = storage.read
|
|
613
|
+
raise StorageError, 'No saved state found' unless data
|
|
614
|
+
|
|
615
|
+
instance = from_json(data)
|
|
616
|
+
instance.storage = storage
|
|
617
|
+
instance
|
|
618
|
+
end
|
|
619
|
+
|
|
620
|
+
# Loads an LSI index from a file (legacy API).
|
|
621
|
+
#
|
|
622
|
+
# @rbs (String) -> LSI
|
|
623
|
+
def self.load_from_file(path)
|
|
624
|
+
from_json(File.read(path))
|
|
625
|
+
end
|
|
626
|
+
|
|
627
|
+
# Loads an LSI index from a checkpoint.
|
|
628
|
+
#
|
|
629
|
+
# @rbs (storage: Storage::Base, checkpoint_id: String) -> LSI
|
|
630
|
+
def self.load_checkpoint(storage:, checkpoint_id:)
|
|
631
|
+
raise ArgumentError, 'Storage must be File storage for checkpoints' unless storage.is_a?(Storage::File)
|
|
632
|
+
|
|
633
|
+
dir = File.dirname(storage.path)
|
|
634
|
+
base = File.basename(storage.path, '.*')
|
|
635
|
+
ext = File.extname(storage.path)
|
|
636
|
+
checkpoint_path = File.join(dir, "#{base}_checkpoint_#{checkpoint_id}#{ext}")
|
|
637
|
+
|
|
638
|
+
checkpoint_storage = Storage::File.new(path: checkpoint_path)
|
|
639
|
+
instance = load(storage: checkpoint_storage)
|
|
640
|
+
instance.storage = storage
|
|
641
|
+
instance
|
|
642
|
+
end
|
|
643
|
+
|
|
644
|
+
# Trains the LSI index from an IO stream.
|
|
645
|
+
# Each line in the stream is treated as a separate document.
|
|
646
|
+
# Documents are added without rebuilding, then the index is rebuilt at the end.
|
|
647
|
+
#
|
|
648
|
+
# @example Train from a file
|
|
649
|
+
# lsi.train_from_stream(:category, File.open('corpus.txt'))
|
|
650
|
+
#
|
|
651
|
+
# @example With progress tracking
|
|
652
|
+
# lsi.train_from_stream(:category, io, batch_size: 500) do |progress|
|
|
653
|
+
# puts "#{progress.completed} documents processed"
|
|
654
|
+
# end
|
|
655
|
+
#
|
|
656
|
+
# @rbs (String | Symbol, IO, ?batch_size: Integer) { (Streaming::Progress) -> void } -> void
|
|
657
|
+
def train_from_stream(category, io, batch_size: Streaming::DEFAULT_BATCH_SIZE)
|
|
658
|
+
original_auto_rebuild = @auto_rebuild
|
|
659
|
+
@auto_rebuild = false
|
|
660
|
+
|
|
661
|
+
begin
|
|
662
|
+
reader = Streaming::LineReader.new(io, batch_size: batch_size)
|
|
663
|
+
total = reader.estimate_line_count
|
|
664
|
+
progress = Streaming::Progress.new(total: total)
|
|
665
|
+
|
|
666
|
+
reader.each_batch do |batch|
|
|
667
|
+
batch.each { |text| add_item(text, category) }
|
|
668
|
+
progress.completed += batch.size
|
|
669
|
+
progress.current_batch += 1
|
|
670
|
+
yield progress if block_given?
|
|
671
|
+
end
|
|
672
|
+
ensure
|
|
673
|
+
@auto_rebuild = original_auto_rebuild
|
|
674
|
+
build_index if original_auto_rebuild
|
|
675
|
+
end
|
|
676
|
+
end
|
|
677
|
+
|
|
678
|
+
# Adds items to the index in batches from an array.
|
|
679
|
+
# Documents are added without rebuilding, then the index is rebuilt at the end.
|
|
680
|
+
#
|
|
681
|
+
# @example Batch add with progress
|
|
682
|
+
# lsi.add_batch(Dog: documents, batch_size: 100) do |progress|
|
|
683
|
+
# puts "#{progress.percent}% complete"
|
|
684
|
+
# end
|
|
685
|
+
#
|
|
686
|
+
# @rbs (?batch_size: Integer, **Array[String]) { (Streaming::Progress) -> void } -> void
|
|
687
|
+
def add_batch(batch_size: Streaming::DEFAULT_BATCH_SIZE, **items)
|
|
688
|
+
original_auto_rebuild = @auto_rebuild
|
|
689
|
+
@auto_rebuild = false
|
|
690
|
+
|
|
691
|
+
begin
|
|
692
|
+
total_docs = items.values.sum { |v| Array(v).size }
|
|
693
|
+
progress = Streaming::Progress.new(total: total_docs)
|
|
694
|
+
|
|
695
|
+
items.each do |category, documents|
|
|
696
|
+
Array(documents).each_slice(batch_size) do |batch|
|
|
697
|
+
batch.each { |doc| add_item(doc, category.to_s) }
|
|
698
|
+
progress.completed += batch.size
|
|
699
|
+
progress.current_batch += 1
|
|
700
|
+
yield progress if block_given?
|
|
701
|
+
end
|
|
702
|
+
end
|
|
703
|
+
ensure
|
|
704
|
+
@auto_rebuild = original_auto_rebuild
|
|
705
|
+
build_index if original_auto_rebuild
|
|
706
|
+
end
|
|
707
|
+
end
|
|
708
|
+
|
|
709
|
+
# Alias train_batch to add_batch for API consistency with other classifiers.
|
|
710
|
+
# Note: LSI uses categories differently (items have categories, not the training call).
|
|
711
|
+
#
|
|
712
|
+
# @rbs (?(String | Symbol)?, ?Array[String]?, ?batch_size: Integer, **Array[String]) { (Streaming::Progress) -> void } -> void
|
|
713
|
+
def train_batch(category = nil, documents = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &block)
|
|
714
|
+
if category && documents
|
|
715
|
+
add_batch(batch_size: batch_size, **{ category.to_sym => documents }, &block)
|
|
716
|
+
else
|
|
717
|
+
add_batch(batch_size: batch_size, **categories, &block)
|
|
718
|
+
end
|
|
719
|
+
end
|
|
720
|
+
|
|
721
|
+
private
|
|
722
|
+
|
|
723
|
+
# Restores LSI state from a JSON string (used by reload).
#
# Raises ArgumentError unless the payload's 'type' field is 'lsi'.
# Under the lock, items are re-created from their serialized word hashes
# and categories; version counters are then set so the index is considered
# stale, and the index is rebuilt after the lock is released.
# @rbs (String) -> void
def restore_from_json(json)
  data = JSON.parse(json)
  raise ArgumentError, "Invalid classifier type: #{data['type']}" unless data['type'] == 'lsi'

  synchronize do
    # Recreate the items from their serialized word hashes and categories.
    @items = {}
    data['items'].each do |item_key, item_data|
      # JSON round-trip turns symbol keys into strings; convert back.
      word_hash = item_data['word_hash'].transform_keys(&:to_sym)
      categories = item_data['categories']
      @items[item_key] = ContentNode.new(word_hash, *categories)
    end

    # Restore settings. Bumping @version while resetting @built_at_version
    # marks the index as needing a rebuild.
    @auto_rebuild = data['auto_rebuild']
    @version += 1
    @built_at_version = -1
    @word_list = WordList.new
    @dirty = false
  end

  # Rebuild the index (called outside the synchronize block)
  build_index
end
|
|
749
|
+
|
|
750
|
+
# Validates that a cutoff fraction lies strictly between 0 and 1,
# raising ArgumentError otherwise.
# @rbs (Float) -> void
def validate_cutoff!(cutoff)
  valid = cutoff.positive? && cutoff < 1
  raise ArgumentError, "cutoff must be between 0 and 1 (exclusive), got #{cutoff}" unless valid
end
|
|
756
|
+
|
|
757
|
+
# Assigns LSI vectors using native C extension
# Each column of the term-document matrix becomes one document's vector.
# @rbs (untyped, Array[ContentNode]) -> void
def assign_native_ext_lsi_vectors(ntdm, doc_list)
  column_count = ntdm.size[1]
  column_count.times do |idx|
    row_vec = self.class.vector_class.alloc(ntdm.column(idx).to_a).row
    node = doc_list[idx]
    node.lsi_vector = row_vec
    node.lsi_norm = row_vec.normalize
  end
end
|
|
766
|
+
|
|
767
|
+
# Assigns LSI vectors using pure Ruby Matrix
# Skips columns with no corresponding document and nil columns.
# @rbs (untyped, Array[ContentNode]) -> void
def assign_ruby_lsi_vectors(ntdm, doc_list)
  ntdm.column_size.times do |idx|
    node = doc_list[idx]
    next unless node

    col_vec = ntdm.column(idx)
    next unless col_vec

    node.lsi_vector = col_vec
    node.lsi_norm = col_vec.normalize
  end
end
|
|
780
|
+
|
|
781
|
+
# Unlocked version of needs_rebuild? for internal use when lock is already held
# A rebuild is needed only with more than one item and a stale build version.
# @rbs () -> bool
def needs_rebuild_unlocked?
  return false if @items.keys.size <= 1

  @version != @built_at_version
end
|
|
786
|
+
|
|
787
|
+
# Unlocked version of proximity_array_for_content for internal use
# Scores every stored item against doc and returns [item, score] pairs
# sorted by descending score.
# @rbs (String) ?{ (String) -> String } -> Array[[String, Float]]
def proximity_array_for_content_unlocked(doc, &)
  return [] if needs_rebuild_unlocked?
  return @items.keys.map { |item| [item, 1.0] } if @items.size == 1

  content_node = node_for_content_unlocked(doc, &)
  native = self.class.native_available?
  scored = @items.keys.collect do |item|
    stored = @items[item].search_vector
    score = if native
              content_node.search_vector * stored.col
            else
              (Matrix[content_node.search_vector] * stored)[0]
            end
    [item, score]
  end
  scored.sort_by { |pair| pair[1] }.reverse
end
|
|
805
|
+
|
|
806
|
+
# Unlocked version of proximity_norms_for_content for internal use
# Like proximity_array_for_content_unlocked but compares normalized vectors.
# @rbs (String) ?{ (String) -> String } -> Array[[String, Float]]
def proximity_norms_for_content_unlocked(doc, &)
  return [] if needs_rebuild_unlocked?

  content_node = node_for_content_unlocked(doc, &)
  native = self.class.native_available?
  scored = @items.keys.collect do |item|
    stored = @items[item].search_norm
    score = if native
              content_node.search_norm * stored.col
            else
              (Matrix[content_node.search_norm] * stored)[0]
            end
    [item, score]
  end
  scored.sort_by { |pair| pair[1] }.reverse
end
|
|
823
|
+
|
|
824
|
+
# Unlocked version of vote for internal use
# Sums the proximity scores of the closest items per category and returns
# the category => accumulated-score hash.
# @rbs (String, ?Float) ?{ (String) -> String } -> Hash[String | Symbol, Float]
def vote_unlocked(doc, cutoff = 0.30, &)
  # Number of nearest neighbors to consider (fraction of the corpus).
  icutoff = (@items.size * cutoff).round
  carry = proximity_array_for_content_unlocked(doc, &)
  # NOTE(review): when icutoff is 0 this slice is carry[0..-1], i.e. ALL
  # items rather than none (Ruby negative-range slicing) — confirm whether
  # that fallback is intended for small corpora.
  carry = carry[0..(icutoff - 1)]
  votes = {}
  carry.each do |pair|
    # pair is [item_key, proximity_score]; an item may carry several categories.
    categories = @items[pair[0]].categories
    categories.each do |category|
      votes[category] ||= 0.0
      votes[category] += pair[1]
    end
  end
  votes
end
|
|
840
|
+
|
|
841
|
+
# Unlocked version of node_for_content for internal use.
# Returns the stored node for item if present; otherwise builds a transient
# ContentNode from the (optionally block-transformed) text.
# @rbs (String) ?{ (String) -> String } -> ContentNode
def node_for_content_unlocked(item, &block)
  cached = @items[item]
  return cached if cached

  source = block ? block.call(item) : item.to_s
  node = ContentNode.new(source.clean_word_hash, &block)
  # Only compute the raw vector when the index is current; otherwise the
  # word list may not match the built index.
  node.raw_vector_with(@word_list) unless needs_rebuild_unlocked?
  assign_lsi_vector_incremental(node) if incremental_enabled?
  node
end
|
|
388
852
|
|
|
853
|
+
# @rbs (untyped, ?Float) -> untyped
|
|
854
|
+
def build_reduced_matrix(matrix, cutoff = 0.75)
|
|
855
|
+
result, _u = build_reduced_matrix_with_u(matrix, cutoff)
|
|
856
|
+
result
|
|
857
|
+
end
|
|
858
|
+
|
|
859
|
+
# Builds reduced matrix and returns both the result and the U matrix.
# U matrix is needed for incremental SVD updates.
#
# Rank reduction: singular values below the cutoff-fraction threshold are
# zeroed, then the matrix is reassembled as U * diag(S) * V^T.
# Side effects: mutates the singular-value vector `s` in place and stores
# the kept values (sorted descending) in @singular_values.
# @rbs (untyped, ?Float) -> [untyped, Matrix]
def build_reduced_matrix_with_u(matrix, cutoff = 0.75)
  u, v, s = matrix.SV_decomp

  # Threshold is the singular value sitting at the cutoff fraction of the
  # spectrum (sorted descending); everything smaller is discarded.
  all_singular_values = s.sort.reverse
  s_cutoff_index = [(s.size * cutoff).round - 1, 0].max
  s_cutoff = all_singular_values[s_cutoff_index]

  kept_indices = []
  kept_singular_values = []
  s.size.times do |ord|
    if s[ord] >= s_cutoff
      kept_indices << ord
      kept_singular_values << s[ord]
    else
      # Zero out discarded singular values in place (mutates s).
      s[ord] = 0.0
    end
  end

  @singular_values = kept_singular_values.sort.reverse
  result = u * self.class.matrix_class.diag(s) * v.trans
  # Some SVD backends return the transposed orientation; normalize it so
  # the result matches the input's row count.
  result = result.trans if result.row_size != matrix.row_size
  u_reduced = extract_reduced_u(u, kept_indices, s)

  [result, u_reduced]
end
|
|
887
|
+
|
|
888
|
+
# Extracts columns from U corresponding to kept singular values.
# Columns are sorted by descending singular value to match @singular_values order.
# Handles native matrices (via to_ruby_matrix or element access) as well as
# stdlib ::Matrix.
# rubocop:disable Naming/MethodParameterName
# @rbs (untyped, Array[Integer], Array[Float]) -> Matrix
def extract_reduced_u(u, kept_indices, singular_values)
  return Matrix.empty(u.row_size, 0) if kept_indices.empty?

  ordered = kept_indices.sort_by { |idx| -singular_values[idx] }

  if u.respond_to?(:to_ruby_matrix)
    u = u.to_ruby_matrix
  elsif !u.is_a?(::Matrix)
    # Foreign matrix type: fall back to element-wise access.
    selected = Array.new(u.row_size) do |r|
      ordered.map { |c| u[r, c] }
    end
    return Matrix.rows(selected)
  end

  Matrix.columns(ordered.map { |idx| u.column(idx).to_a })
end
# rubocop:enable Naming/MethodParameterName
|
|
910
|
+
|
|
389
911
|
# @rbs () -> void
|
|
390
912
|
def make_word_list
|
|
391
913
|
@word_list = WordList.new
|
|
@@ -393,5 +915,129 @@ module Classifier
|
|
|
393
915
|
node.word_hash.each_key { |key| @word_list.add_word key }
|
|
394
916
|
end
|
|
395
917
|
end
|
|
918
|
+
|
|
919
|
+
# Performs incremental SVD update for a new document.
#
# If vocabulary growth exceeds the incremental threshold, incremental mode
# is disabled and a full rebuild is performed instead (outside the lock).
# Otherwise the vocabulary and U matrix are extended, IncrementalSVD folds
# the new document in, and either just this node or (if the rank grew)
# every stored document is re-projected.
# @rbs (ContentNode, Hash[Symbol, Integer]) -> void
def perform_incremental_update(node, word_hash)
  needs_full_rebuild = false
  old_rank = nil

  synchronize do
    if vocabulary_growth_exceeds_threshold?(word_hash)
      # Too many new words for a cheap update; schedule a full rebuild.
      disable_incremental_mode!
      needs_full_rebuild = true
      next # exits only the synchronize block
    end

    old_rank = @u_matrix.column_size
    extend_vocabulary_for_incremental(word_hash)
    raw_vec = node.raw_vector_with(@word_list)
    raw_vector = Vector[*raw_vec.to_a]

    @u_matrix, @singular_values = IncrementalSVD.update(
      @u_matrix, @singular_values, raw_vector, max_rank: @max_rank
    )

    new_rank = @u_matrix.column_size
    if new_rank > old_rank
      # Rank grew: stored LSI vectors change size, so re-project everything.
      reproject_all_documents
    else
      assign_lsi_vector_incremental(node)
    end

    @built_at_version = @version
  end

  # Full rebuild runs outside the synchronize block.
  build_index if needs_full_rebuild
end
|
|
953
|
+
|
|
954
|
+
# Checks if vocabulary growth would exceed threshold (20%)
# Compares the number of unseen words against the vocabulary size recorded
# when incremental mode started; always false without a baseline.
# @rbs (Hash[Symbol, Integer]) -> bool
def vocabulary_growth_exceeds_threshold?(word_hash)
  return false unless @initial_vocab_size&.positive?

  unseen = word_hash.keys.count { |word| @word_list[word].nil? }
  (unseen.to_f / @initial_vocab_size) > 0.2
end
|
|
963
|
+
|
|
964
|
+
# Extends vocabulary and U matrix for new words.
# No-op when every word is already known.
# @rbs (Hash[Symbol, Integer]) -> void
def extend_vocabulary_for_incremental(word_hash)
  unseen = word_hash.keys.select { |word| @word_list[word].nil? }
  return if unseen.empty?

  unseen.each { |word| @word_list.add_word(word) }
  # U needs a zero row per new vocabulary term to stay conformable.
  extend_u_matrix(unseen.size)
end
|
|
973
|
+
|
|
974
|
+
# Extends U matrix with zero rows for new vocabulary terms.
# Uses the native matrix class when available and applicable, otherwise
# falls back to stdlib Matrix.
# @rbs (Integer) -> void
def extend_u_matrix(num_new_rows)
  return if num_new_rows.zero? || @u_matrix.nil?

  cols = @u_matrix.column_size
  if self.class.native_available? && @u_matrix.is_a?(self.class.matrix_class)
    padding = self.class.matrix_class.zeros(num_new_rows, cols)
    @u_matrix = self.class.matrix_class.vstack(@u_matrix, padding)
  else
    padding = Matrix.zero(num_new_rows, cols)
    @u_matrix = Matrix.vstack(@u_matrix, padding)
  end
end
|
|
987
|
+
|
|
988
|
+
# Re-projects all documents onto the current U matrix
# Called when rank grows to ensure consistent LSI vector sizes
# Uses native batch_project for performance when available
# @rbs () -> void
def reproject_all_documents
  return unless @u_matrix

  if self.class.native_available? && @u_matrix.respond_to?(:batch_project)
    reproject_all_documents_native
  else
    reproject_all_documents_ruby
  end
end
|
|
998
|
+
|
|
999
|
+
# Native batch re-projection using C extension.
# Projects every stored document's raw vector through U in one call and
# writes the resulting LSI vector and its norm back onto each node.
# @rbs () -> void
def reproject_all_documents_native
  nodes = @items.values
  raw_vectors = nodes.map do |node|
    raw = node.raw_vector_with(@word_list)
    # Coerce to the native vector class when the node returned something else.
    raw.is_a?(self.class.vector_class) ? raw : self.class.vector_class.alloc(raw.to_a)
  end

  projected = @u_matrix.batch_project(raw_vectors)

  nodes.zip(projected).each do |node, vec|
    row_vec = vec.row
    node.lsi_vector = row_vec
    node.lsi_norm = row_vec.normalize
  end
end
|
|
1016
|
+
|
|
1017
|
+
# Pure Ruby re-projection (fallback)
# Re-projects each stored node individually through the U matrix.
# @rbs () -> void
def reproject_all_documents_ruby
  @items.each_value { |node| assign_lsi_vector_incremental(node) }
end
|
|
1024
|
+
|
|
1025
|
+
# Assigns LSI vector to a node using projection: lsi_vec = U^T * raw_vec.
# No-op when no U matrix is available. The projected vector is wrapped in
# the native vector class when available, stdlib Vector otherwise.
# @rbs (ContentNode) -> void
def assign_lsi_vector_incremental(node)
  return unless @u_matrix

  raw = node.raw_vector_with(@word_list)
  projection = (@u_matrix.transpose * Vector[*raw.to_a]).to_a

  lsi_vec =
    if self.class.native_available?
      self.class.vector_class.alloc(projection).row
    else
      Vector[*projection]
    end
  node.lsi_vector = lsi_vec
  node.lsi_norm = lsi_vec.normalize
end
|
|
396
1042
|
end
|
|
397
1043
|
end
|