classifier 2.1.0 → 2.2.0

@@ -0,0 +1,387 @@
+ /*
+  * incremental_svd.c
+  * Native C implementation of Brand's incremental SVD operations
+  *
+  * Provides fast matrix operations for:
+  *   - Matrix column extension
+  *   - Vertical stacking (vstack)
+  *   - Vector subtraction
+  *   - Batch document projection
+  */
+
+ #include "linalg.h"
+
+ /*
+  * Extend a matrix with a new column
+  * Returns a new matrix [M | col] with one additional column
+  */
+ CMatrix *cmatrix_extend_column(CMatrix *m, CVector *col)
+ {
+     if (m->rows != col->size) {
+         rb_raise(rb_eArgError,
+                  "Matrix rows (%ld) must match vector size (%ld)",
+                  (long)m->rows, (long)col->size);
+     }
+
+     CMatrix *result = cmatrix_alloc(m->rows, m->cols + 1);
+
+     for (size_t i = 0; i < m->rows; i++) {
+         memcpy(&MAT_AT(result, i, 0), &MAT_AT(m, i, 0), m->cols * sizeof(double));
+         MAT_AT(result, i, m->cols) = col->data[i];
+     }
+
+     return result;
+ }
+
+ /*
+  * Vertically stack two matrices
+  * Returns a new matrix [top; bottom]
+  */
+ CMatrix *cmatrix_vstack(CMatrix *top, CMatrix *bottom)
+ {
+     if (top->cols != bottom->cols) {
+         rb_raise(rb_eArgError,
+                  "Matrices must have same column count: %ld vs %ld",
+                  (long)top->cols, (long)bottom->cols);
+     }
+
+     size_t new_rows = top->rows + bottom->rows;
+     CMatrix *result = cmatrix_alloc(new_rows, top->cols);
+
+     memcpy(result->data, top->data, top->rows * top->cols * sizeof(double));
+     memcpy(result->data + top->rows * top->cols,
+            bottom->data,
+            bottom->rows * bottom->cols * sizeof(double));
+
+     return result;
+ }
+
+ /*
+  * Vector subtraction: a - b
+  */
+ CVector *cvector_subtract(CVector *a, CVector *b)
+ {
+     if (a->size != b->size) {
+         rb_raise(rb_eArgError,
+                  "Vector sizes must match: %ld vs %ld",
+                  (long)a->size, (long)b->size);
+     }
+
+     CVector *result = cvector_alloc(a->size);
+     for (size_t i = 0; i < a->size; i++) {
+         result->data[i] = a->data[i] - b->data[i];
+     }
+     return result;
+ }
+
+ /*
+  * Batch project multiple vectors onto U matrix
+  * Computes lsi_vector = U^T * raw_vector for each vector
+  * This is the most performance-critical operation for incremental updates
+  */
+ void cbatch_project(CMatrix *u, CVector **raw_vectors, size_t num_vectors,
+                     CVector **lsi_vectors_out)
+ {
+     size_t m = u->rows; /* vocabulary size */
+     size_t k = u->cols; /* rank */
+
+     for (size_t v = 0; v < num_vectors; v++) {
+         CVector *raw = raw_vectors[v];
+         if (raw->size != m) {
+             rb_raise(rb_eArgError,
+                      "Vector %ld size (%ld) must match matrix rows (%ld)",
+                      (long)v, (long)raw->size, (long)m);
+         }
+
+         CVector *lsi = cvector_alloc(k);
+
+         /* Compute U^T * raw (project onto k-dimensional space) */
+         for (size_t j = 0; j < k; j++) {
+             double sum = 0.0;
+             for (size_t i = 0; i < m; i++) {
+                 sum += MAT_AT(u, i, j) * raw->data[i];
+             }
+             lsi->data[j] = sum;
+         }
+
+         lsi_vectors_out[v] = lsi;
+     }
+ }
+
+ /*
+  * Build the K matrix for Brand's algorithm when rank grows
+  *     K = | diag(s)  m_vec  |
+  *         |   0      p_norm |
+  */
+ static CMatrix *build_k_matrix_with_growth(CVector *s, CVector *m_vec, double p_norm)
+ {
+     size_t k = s->size;
+     CMatrix *result = cmatrix_alloc(k + 1, k + 1);
+
+     /* First k rows: diagonal s values and m_vec in last column */
+     for (size_t i = 0; i < k; i++) {
+         MAT_AT(result, i, i) = s->data[i];
+         MAT_AT(result, i, k) = m_vec->data[i];
+     }
+
+     /* Last row: zeros except p_norm in last position */
+     MAT_AT(result, k, k) = p_norm;
+
+     return result;
+ }
+
+ /*
+  * Perform one incremental SVD update using Brand's algorithm
+  *
+  * @param u        Current U matrix (m x k)
+  * @param s        Current singular values (k values)
+  * @param c        New document vector (m x 1)
+  * @param max_rank Maximum rank to maintain
+  * @param epsilon  Threshold for detecting new directions
+  * @param u_out    Output: updated U matrix
+  * @param s_out    Output: updated singular values
+  */
+ static void incremental_update(CMatrix *u, CVector *s, CVector *c, int max_rank,
+                                double epsilon, CMatrix **u_out, CVector **s_out)
+ {
+     size_t m = u->rows;
+     size_t k = u->cols;
+
+     /* Step 1: Project c onto column space of U */
+     /* m_vec = U^T * c */
+     CVector *m_vec = cvector_alloc(k);
+     for (size_t j = 0; j < k; j++) {
+         double sum = 0.0;
+         for (size_t i = 0; i < m; i++) {
+             sum += MAT_AT(u, i, j) * c->data[i];
+         }
+         m_vec->data[j] = sum;
+     }
+
+     /* Step 2: Compute residual p = c - U * m_vec */
+     CVector *u_times_m = cmatrix_multiply_vector(u, m_vec);
+     CVector *p = cvector_subtract(c, u_times_m);
+     double p_norm = cvector_magnitude(p);
+
+     cvector_free(u_times_m);
+
+     if (p_norm > epsilon) {
+         /* New direction found - rank may increase */
+
+         /* Step 3: Normalize residual */
+         CVector *p_hat = cvector_alloc(m);
+         double inv_p_norm = 1.0 / p_norm;
+         for (size_t i = 0; i < m; i++) {
+             p_hat->data[i] = p->data[i] * inv_p_norm;
+         }
+
+         /* Step 4: Build K matrix */
+         CMatrix *k_mat = build_k_matrix_with_growth(s, m_vec, p_norm);
+
+         /* Step 5: SVD of K matrix */
+         CMatrix *u_prime, *v_prime;
+         CVector *s_prime;
+         jacobi_svd(k_mat, &u_prime, &v_prime, &s_prime);
+         cmatrix_free(k_mat);
+         cmatrix_free(v_prime);
+
+         /* Step 6: Update U = [U | p_hat] * U' */
+         CMatrix *u_extended = cmatrix_extend_column(u, p_hat);
+         CMatrix *u_new = cmatrix_multiply(u_extended, u_prime);
+         cmatrix_free(u_extended);
+         cmatrix_free(u_prime);
+         cvector_free(p_hat);
+
+         /* Truncate if needed */
+         if (s_prime->size > (size_t)max_rank) {
+             /* Create truncated U (keep first max_rank columns) */
+             CMatrix *u_trunc = cmatrix_alloc(u_new->rows, (size_t)max_rank);
+             for (size_t i = 0; i < u_new->rows; i++) {
+                 memcpy(&MAT_AT(u_trunc, i, 0), &MAT_AT(u_new, i, 0),
+                        (size_t)max_rank * sizeof(double));
+             }
+             cmatrix_free(u_new);
+             u_new = u_trunc;
+
+             /* Truncate singular values */
+             CVector *s_trunc = cvector_alloc((size_t)max_rank);
+             memcpy(s_trunc->data, s_prime->data, (size_t)max_rank * sizeof(double));
+             cvector_free(s_prime);
+             s_prime = s_trunc;
+         }
+
+         *u_out = u_new;
+         *s_out = s_prime;
+     } else {
+         /* Vector in span - use simpler update */
+         /* For now, just return unchanged (projection handles this) */
+         *u_out = cmatrix_alloc(u->rows, u->cols);
+         memcpy((*u_out)->data, u->data, u->rows * u->cols * sizeof(double));
+         *s_out = cvector_alloc(s->size);
+         memcpy((*s_out)->data, s->data, s->size * sizeof(double));
+     }
+
+     cvector_free(p);
+     cvector_free(m_vec);
+ }
+
+ /* ========== Ruby Wrappers ========== */
+
+ /*
+  * Matrix.extend_column(matrix, vector)
+  * Returns [matrix | vector]
+  */
+ static VALUE rb_cmatrix_extend_column(VALUE klass, VALUE rb_matrix, VALUE rb_vector)
+ {
+     CMatrix *m;
+     CVector *v;
+
+     GET_CMATRIX(rb_matrix, m);
+     GET_CVECTOR(rb_vector, v);
+
+     CMatrix *result = cmatrix_extend_column(m, v);
+     return TypedData_Wrap_Struct(klass, &cmatrix_type, result);
+ }
+
+ /*
+  * Matrix.vstack(top, bottom)
+  * Vertically stack two matrices
+  */
+ static VALUE rb_cmatrix_vstack(VALUE klass, VALUE rb_top, VALUE rb_bottom)
+ {
+     CMatrix *top, *bottom;
+
+     GET_CMATRIX(rb_top, top);
+     GET_CMATRIX(rb_bottom, bottom);
+
+     CMatrix *result = cmatrix_vstack(top, bottom);
+     return TypedData_Wrap_Struct(klass, &cmatrix_type, result);
+ }
+
+ /*
+  * Matrix.zeros(rows, cols)
+  * Create a zero matrix
+  */
+ static VALUE rb_cmatrix_zeros(VALUE klass, VALUE rb_rows, VALUE rb_cols)
+ {
+     size_t rows = NUM2SIZET(rb_rows);
+     size_t cols = NUM2SIZET(rb_cols);
+
+     CMatrix *result = cmatrix_alloc(rows, cols);
+     return TypedData_Wrap_Struct(klass, &cmatrix_type, result);
+ }
+
+ /*
+  * Vector#-(other)
+  * Vector subtraction
+  */
+ static VALUE rb_cvector_subtract(VALUE self, VALUE other)
+ {
+     CVector *a, *b;
+
+     GET_CVECTOR(self, a);
+
+     if (rb_obj_is_kind_of(other, cClassifierVector)) {
+         GET_CVECTOR(other, b);
+         CVector *result = cvector_subtract(a, b);
+         return TypedData_Wrap_Struct(cClassifierVector, &cvector_type, result);
+     }
+
+     rb_raise(rb_eTypeError, "Cannot subtract %s from Vector",
+              rb_obj_classname(other));
+     return Qnil;
+ }
+
+ /*
+  * Matrix#batch_project(vectors_array)
+  * Project multiple vectors onto this matrix (as U)
+  * Returns array of projected vectors
+  *
+  * This is the high-performance batch operation for re-projecting documents
+  */
+ static VALUE rb_cmatrix_batch_project(VALUE self, VALUE rb_vectors)
+ {
+     CMatrix *u;
+     GET_CMATRIX(self, u);
+
+     Check_Type(rb_vectors, T_ARRAY);
+     long num_vectors = RARRAY_LEN(rb_vectors);
+
+     if (num_vectors == 0) {
+         return rb_ary_new();
+     }
+
+     CVector **raw_vectors = ALLOC_N(CVector *, num_vectors);
+     for (long i = 0; i < num_vectors; i++) {
+         VALUE rb_vec = rb_ary_entry(rb_vectors, i);
+         if (!rb_obj_is_kind_of(rb_vec, cClassifierVector)) {
+             xfree(raw_vectors);
+             rb_raise(rb_eTypeError, "Expected array of Vectors");
+         }
+         GET_CVECTOR(rb_vec, raw_vectors[i]);
+     }
+
+     CVector **lsi_vectors = ALLOC_N(CVector *, num_vectors);
+     cbatch_project(u, raw_vectors, (size_t)num_vectors, lsi_vectors);
+
+     VALUE result = rb_ary_new_capa(num_vectors);
+     for (long i = 0; i < num_vectors; i++) {
+         VALUE rb_lsi = TypedData_Wrap_Struct(cClassifierVector, &cvector_type,
+                                              lsi_vectors[i]);
+         rb_ary_push(result, rb_lsi);
+     }
+
+     xfree(raw_vectors);
+     xfree(lsi_vectors);
+
+     return result;
+ }
+
+ /*
+  * Matrix#incremental_svd_update(singular_values, new_vector, max_rank, epsilon)
+  * Perform one incremental SVD update using Brand's algorithm
+  * Returns [new_u, new_singular_values]
+  */
+ static VALUE rb_cmatrix_incremental_update(VALUE self, VALUE rb_s, VALUE rb_c,
+                                            VALUE rb_max_rank, VALUE rb_epsilon)
+ {
+     CMatrix *u;
+     CVector *s, *c;
+
+     GET_CMATRIX(self, u);
+     GET_CVECTOR(rb_s, s);
+     GET_CVECTOR(rb_c, c);
+
+     int max_rank = NUM2INT(rb_max_rank);
+     double epsilon = NUM2DBL(rb_epsilon);
+
+     CMatrix *u_new;
+     CVector *s_new;
+
+     incremental_update(u, s, c, max_rank, epsilon, &u_new, &s_new);
+
+     VALUE rb_u_new = TypedData_Wrap_Struct(cClassifierMatrix, &cmatrix_type, u_new);
+     VALUE rb_s_new = TypedData_Wrap_Struct(cClassifierVector, &cvector_type, s_new);
+
+     return rb_ary_new_from_args(2, rb_u_new, rb_s_new);
+ }
+
+ void Init_incremental_svd(void)
+ {
+     /* Matrix class methods for incremental SVD */
+     rb_define_singleton_method(cClassifierMatrix, "extend_column",
+                                rb_cmatrix_extend_column, 2);
+     rb_define_singleton_method(cClassifierMatrix, "vstack",
+                                rb_cmatrix_vstack, 2);
+     rb_define_singleton_method(cClassifierMatrix, "zeros",
+                                rb_cmatrix_zeros, 2);
+
+     /* Instance methods */
+     rb_define_method(cClassifierMatrix, "batch_project",
+                      rb_cmatrix_batch_project, 1);
+     rb_define_method(cClassifierMatrix, "incremental_svd_update",
+                      rb_cmatrix_incremental_update, 4);
+
+     /* Vector subtraction */
+     rb_define_method(cClassifierVector, "-", rb_cvector_subtract, 1);
+ }
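
For orientation, here is a minimal sketch of how these bindings compose from the Ruby side. The method names (Matrix.zeros, Matrix#incremental_svd_update, Matrix#batch_project, Vector#-) are confirmed by Init_incremental_svd above; the Classifier::Matrix and Classifier::Vector constant names are inferred from the C handles cClassifierMatrix and cClassifierVector, and how instances are first constructed is not shown in this diff. The variables and the epsilon value are illustrative.

    # Hypothetical driver for one Brand update: u is the current U matrix
    # (vocab_size x k), s the current singular values, doc_vec a raw
    # vocab_size-dimensional document vector.
    u, s = u.incremental_svd_update(s, doc_vec, max_rank, 1.0e-6)

    # After the basis changes, re-project cached document vectors in one
    # native call; each result is U^T * raw, a k-dimensional vector.
    lsi_docs = u.batch_project(raw_docs)

Note that incremental_svd_update returns fresh [U, s] objects rather than mutating the receiver, so the caller rebinds both; the epsilon argument decides when the residual norm counts as a genuinely new direction rather than numerical noise.
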
@@ -50,6 +50,14 @@ CMatrix *cmatrix_diagonal(CVector *v);
  void Init_svd(void);
  void jacobi_svd(CMatrix *a, CMatrix **u, CMatrix **v, CVector **s);
 
+ /* Incremental SVD functions */
+ void Init_incremental_svd(void);
+ CMatrix *cmatrix_extend_column(CMatrix *m, CVector *col);
+ CMatrix *cmatrix_vstack(CMatrix *top, CMatrix *bottom);
+ CVector *cvector_subtract(CVector *a, CVector *b);
+ void cbatch_project(CMatrix *u, CVector **raw_vectors, size_t num_vectors,
+                     CVector **lsi_vectors_out);
+
  /* TypedData definitions */
  extern const rb_data_type_t cvector_type;
  extern const rb_data_type_t cmatrix_type;
@@ -8,8 +8,9 @@ require 'json'
  require 'mutex_m'
 
  module Classifier
-   class Bayes
+   class Bayes # rubocop:disable Metrics/ClassLength
      include Mutex_m
+     include Streaming
 
      # @rbs @categories: Hash[Symbol, Hash[Symbol, Integer]]
      # @rbs @total_words: Integer
@@ -25,11 +26,12 @@ module Classifier
      # The class can be created with one or more categories, each of which will be
      # initialized and given a training method. E.g.,
      #   b = Classifier::Bayes.new 'Interesting', 'Uninteresting', 'Spam'
-     # @rbs (*String | Symbol) -> void
+     #   b = Classifier::Bayes.new ['Interesting', 'Uninteresting', 'Spam']
+     # @rbs (*String | Symbol | Array[String | Symbol]) -> void
      def initialize(*categories)
        super()
        @categories = {}
-       categories.each { |category| @categories[category.prepare_category_name] = {} }
+       categories.flatten.each { |category| @categories[category.prepare_category_name] = {} }
        @total_words = 0
        @category_counts = Hash.new(0)
        @category_word_count = Hash.new(0)
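
The flatten call is what makes the splat and array forms in the doc comment equivalent; a two-line sketch:

    Classifier::Bayes.new('Spam', 'Ham')
    Classifier::Bayes.new(%w[Spam Ham])   # same two categories
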
@@ -39,59 +41,31 @@ module Classifier
        @storage = nil
      end
 
-     # Provides a general training method for all categories specified in Bayes#new
-     # For example:
-     #   b = Classifier::Bayes.new 'This', 'That', 'the_other'
-     #   b.train :this, "This text"
-     #   b.train "that", "That text"
-     #   b.train "The other", "The other text"
+     # Trains the classifier with text for a category.
      #
-     # @rbs (String | Symbol, String) -> void
-     def train(category, text)
-       category = category.prepare_category_name
-       word_hash = text.word_hash
-       synchronize do
-         invalidate_caches
-         @dirty = true
-         @category_counts[category] += 1
-         word_hash.each do |word, count|
-           @categories[category][word] ||= 0
-           @categories[category][word] += count
-           @total_words += count
-           @category_word_count[category] += count
-         end
+     #   b.train(spam: "Buy now!", ham: ["Hello", "Meeting tomorrow"])
+     #   b.train(:spam, "legacy positional API")
+     #
+     # @rbs (?(String | Symbol)?, ?String?, **(String | Array[String])) -> void
+     def train(category = nil, text = nil, **categories)
+       return train_single(category, text) if category && text
+
+       categories.each do |cat, texts|
+         (texts.is_a?(Array) ? texts : [texts]).each { |t| train_single(cat, t) }
        end
      end
 
-     # Provides a untraining method for all categories specified in Bayes#new
-     # Be very careful with this method.
+     # Removes training data. Be careful with this method.
      #
-     # For example:
-     #   b = Classifier::Bayes.new 'This', 'That', 'the_other'
-     #   b.train :this, "This text"
-     #   b.untrain :this, "This text"
+     #   b.untrain(spam: "Buy now!")
+     #   b.untrain(:spam, "legacy positional API")
      #
-     # @rbs (String | Symbol, String) -> void
-     def untrain(category, text)
-       category = category.prepare_category_name
-       word_hash = text.word_hash
-       synchronize do
-         invalidate_caches
-         @dirty = true
-         @category_counts[category] -= 1
-         word_hash.each do |word, count|
-           next unless @total_words >= 0
+     # @rbs (?(String | Symbol)?, ?String?, **(String | Array[String])) -> void
+     def untrain(category = nil, text = nil, **categories)
+       return untrain_single(category, text) if category && text
 
-           orig = @categories[category][word] || 0
-           @categories[category][word] ||= 0
-           @categories[category][word] -= count
-           if @categories[category][word] <= 0
-             @categories[category].delete(word)
-             count = orig
-           end
-           @category_word_count[category] -= count if @category_word_count[category] >= count
-           @total_words -= count
-         end
+       categories.each do |cat, texts|
+         (texts.is_a?(Array) ? texts : [texts]).each { |t| untrain_single(cat, t) }
        end
      end
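
Both reworked methods keep the old two-argument form and add a keyword form that simply loops over category => text(s) pairs, funneling everything into train_single/untrain_single. A short sketch assembled from the doc examples above (category names are illustrative):

    b = Classifier::Bayes.new 'Spam', 'Ham'
    b.train(spam: 'Buy now!', ham: ['Hello', 'Meeting tomorrow'])  # keyword form, one or many docs
    b.train(:ham, 'Lunch on Friday?')                              # legacy positional form
    b.untrain(spam: 'Buy now!')                                    # reverses the matching train call
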
 
@@ -135,8 +109,8 @@ module Classifier
      # Returns a hash representation of the classifier state.
      # This can be converted to JSON or used directly.
      #
-     # @rbs () -> untyped
-     def as_json(*)
+     # @rbs (?untyped) -> untyped
+     def as_json(_options = nil)
        {
          version: 1,
          type: 'bayes',
@@ -150,8 +124,8 @@ module Classifier
      # Serializes the classifier state to a JSON string.
      # This can be saved to a file and later loaded with Bayes.from_json.
      #
-     # @rbs () -> String
-     def to_json(*)
+     # @rbs (?untyped) -> String
+     def to_json(_options = nil)
        as_json.to_json
      end
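
Per the comments above, as_json/to_json pair with Bayes.from_json for a full round trip; a minimal sketch, with classifier standing in for any trained Bayes instance:

    json = classifier.to_json                    # JSON string with version, type, counts
    restored = Classifier::Bayes.from_json(json)
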
 
@@ -338,8 +312,158 @@ module Classifier
        end
      end
 
+     # Trains the classifier from an IO stream.
+     # Each line in the stream is treated as a separate document.
+     # This is memory-efficient for large corpora.
+     #
+     # @example Train from a file
+     #   classifier.train_from_stream(:spam, File.open('spam_corpus.txt'))
+     #
+     # @example With progress tracking
+     #   classifier.train_from_stream(:spam, io, batch_size: 500) do |progress|
+     #     puts "#{progress.completed} documents processed"
+     #   end
+     #
+     # @rbs (String | Symbol, IO, ?batch_size: Integer) { (Streaming::Progress) -> void } -> void
+     def train_from_stream(category, io, batch_size: Streaming::DEFAULT_BATCH_SIZE)
+       category = category.prepare_category_name
+       raise StandardError, "No such category: #{category}" unless @categories.key?(category)
+
+       reader = Streaming::LineReader.new(io, batch_size: batch_size)
+       total = reader.estimate_line_count
+       progress = Streaming::Progress.new(total: total)
+
+       reader.each_batch do |batch|
+         train_batch_internal(category, batch)
+         progress.completed += batch.size
+         progress.current_batch += 1
+         yield progress if block_given?
+       end
+     end
+
+     # Trains the classifier with an array of documents in batches.
+     # Reduces lock contention by processing multiple documents per synchronize call.
+     #
+     # @example Positional style
+     #   classifier.train_batch(:spam, documents, batch_size: 100)
+     #
+     # @example Keyword style
+     #   classifier.train_batch(spam: documents, ham: other_docs, batch_size: 100)
+     #
+     # @example With progress tracking
+     #   classifier.train_batch(:spam, documents, batch_size: 100) do |progress|
+     #     puts "#{progress.percent}% complete"
+     #   end
+     #
+     # @rbs (?(String | Symbol)?, ?Array[String]?, ?batch_size: Integer, **Array[String]) { (Streaming::Progress) -> void } -> void
+     def train_batch(category = nil, documents = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &block)
+       if category && documents
+         train_batch_for_category(category, documents, batch_size: batch_size, &block)
+       else
+         categories.each do |cat, docs|
+           train_batch_for_category(cat, Array(docs), batch_size: batch_size, &block)
+         end
+       end
+     end
+
+     # Loads a classifier from a checkpoint.
+     #
+     # @rbs (storage: Storage::Base, checkpoint_id: String) -> Bayes
+     def self.load_checkpoint(storage:, checkpoint_id:)
+       raise ArgumentError, 'Storage must be File storage for checkpoints' unless storage.is_a?(Storage::File)
+
+       dir = File.dirname(storage.path)
+       base = File.basename(storage.path, '.*')
+       ext = File.extname(storage.path)
+       checkpoint_path = File.join(dir, "#{base}_checkpoint_#{checkpoint_id}#{ext}")
+
+       checkpoint_storage = Storage::File.new(path: checkpoint_path)
+       instance = load(storage: checkpoint_storage)
+       instance.storage = storage
+       instance
+     end
+
      private
 
+     # Trains a batch of documents for a single category.
+     # @rbs (String | Symbol, Array[String], ?batch_size: Integer) { (Streaming::Progress) -> void } -> void
+     def train_batch_for_category(category, documents, batch_size: Streaming::DEFAULT_BATCH_SIZE)
+       category = category.prepare_category_name
+       raise StandardError, "No such category: #{category}" unless @categories.key?(category)
+
+       progress = Streaming::Progress.new(total: documents.size)
+
+       documents.each_slice(batch_size) do |batch|
+         train_batch_internal(category, batch)
+         progress.completed += batch.size
+         progress.current_batch += 1
+         yield progress if block_given?
+       end
+     end
+
+     # Internal method to train a batch of documents.
+     # Uses a single synchronize block for the entire batch.
+     # @rbs (Symbol, Array[String]) -> void
+     def train_batch_internal(category, batch)
+       synchronize do
+         invalidate_caches
+         @dirty = true
+         batch.each do |text|
+           word_hash = text.word_hash
+           @category_counts[category] += 1
+           word_hash.each do |word, count|
+             @categories[category][word] ||= 0
+             @categories[category][word] += count
+             @total_words += count
+             @category_word_count[category] += count
+           end
+         end
+       end
+     end
+
+     # Core training logic for a single category and text.
+     # @rbs (String | Symbol, String) -> void
+     def train_single(category, text)
+       category = category.prepare_category_name
+       word_hash = text.word_hash
+       synchronize do
+         invalidate_caches
+         @dirty = true
+         @category_counts[category] += 1
+         word_hash.each do |word, count|
+           @categories[category][word] ||= 0
+           @categories[category][word] += count
+           @total_words += count
+           @category_word_count[category] += count
+         end
+       end
+     end
+
+     # Core untraining logic for a single category and text.
+     # @rbs (String | Symbol, String) -> void
+     def untrain_single(category, text)
+       category = category.prepare_category_name
+       word_hash = text.word_hash
+       synchronize do
+         invalidate_caches
+         @dirty = true
+         @category_counts[category] -= 1
+         word_hash.each do |word, count|
+           next unless @total_words >= 0
+
+           orig = @categories[category][word] || 0
+           @categories[category][word] ||= 0
+           @categories[category][word] -= count
+           if @categories[category][word] <= 0
+             @categories[category].delete(word)
+             count = orig
+           end
+           @category_word_count[category] -= count if @category_word_count[category] >= count
+           @total_words -= count
+         end
+       end
+     end
+
      # Restores classifier state from a JSON string (used by reload)
      # @rbs (String) -> void
      def restore_from_json(json)
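
load_checkpoint resolves a sibling file derived from the primary storage path, then rebinds @storage to the primary path so later saves do not target the checkpoint file. A sketch (file names illustrative; the Classifier::Storage::File constant path is inferred from the Storage::File reference above, and how checkpoints get written belongs to the Streaming module, which is not shown in this diff):

    storage = Classifier::Storage::File.new(path: 'model.json')
    # Looks for model_checkpoint_epoch3.json next to model.json:
    bayes = Classifier::Bayes.load_checkpoint(storage: storage, checkpoint_id: 'epoch3')
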
@@ -13,4 +13,7 @@ module Classifier
 
    # Raised when a storage operation fails
    class StorageError < Error; end
+
+   # Raised when using an unfitted model
+   class NotFittedError < Error; end
  end