classifier 2.0.0 → 2.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,319 @@
1
+ /*
2
+ * vector.c
3
+ * Vector implementation for Classifier native linear algebra
4
+ */
5
+
6
+ #include "linalg.h"
7
+
8
+ const rb_data_type_t cvector_type = {
9
+ .wrap_struct_name = "Classifier::Linalg::Vector",
10
+ .function = {
11
+ .dmark = NULL,
12
+ .dfree = cvector_free,
13
+ .dsize = NULL,
14
+ },
15
+ .flags = RUBY_TYPED_FREE_IMMEDIATELY
16
+ };
17
+
18
+ /* Allocate a new CVector */
19
+ CVector *cvector_alloc(size_t size)
20
+ {
21
+ CVector *v = ALLOC(CVector);
22
+ v->size = size;
23
+ v->data = ALLOC_N(double, size);
24
+ v->is_col = 0; /* Default to row vector */
25
+ memset(v->data, 0, size * sizeof(double));
26
+ return v;
27
+ }
28
+
29
+ /* Free a CVector */
30
+ void cvector_free(void *ptr)
31
+ {
32
+ CVector *v = (CVector *)ptr;
33
+ if (v) {
34
+ if (v->data) xfree(v->data);
35
+ xfree(v);
36
+ }
37
+ }
38
+
39
+ /* Calculate magnitude (Euclidean norm) */
40
+ double cvector_magnitude(CVector *v)
41
+ {
42
+ double sum = 0.0;
43
+ for (size_t i = 0; i < v->size; i++) {
44
+ sum += v->data[i] * v->data[i];
45
+ }
46
+ return sqrt(sum);
47
+ }
48
+
49
+ /* Return normalized copy */
50
+ CVector *cvector_normalize(CVector *v)
51
+ {
52
+ CVector *result = cvector_alloc(v->size);
53
+ result->is_col = v->is_col;
54
+ double mag = cvector_magnitude(v);
55
+
56
+ if (mag <= CLASSIFIER_EPSILON) {
57
+ /* Return zero vector if magnitude is too small */
58
+ return result;
59
+ }
60
+
61
+ for (size_t i = 0; i < v->size; i++) {
62
+ result->data[i] = v->data[i] / mag;
63
+ }
64
+ return result;
65
+ }
66
+
67
+ /* Sum all elements */
68
+ double cvector_sum(CVector *v)
69
+ {
70
+ double sum = 0.0;
71
+ for (size_t i = 0; i < v->size; i++) {
72
+ sum += v->data[i];
73
+ }
74
+ return sum;
75
+ }
76
+
77
+ /* Dot product */
78
+ double cvector_dot(CVector *a, CVector *b)
79
+ {
80
+ if (a->size != b->size) {
81
+ rb_raise(rb_eArgError, "Vector sizes must match for dot product");
82
+ }
83
+ double sum = 0.0;
84
+ for (size_t i = 0; i < a->size; i++) {
85
+ sum += a->data[i] * b->data[i];
86
+ }
87
+ return sum;
88
+ }
89
+
90
+ /* Ruby allocation function */
91
+ static VALUE rb_cvector_alloc(VALUE klass)
92
+ {
93
+ CVector *v = cvector_alloc(0);
94
+ return TypedData_Wrap_Struct(klass, &cvector_type, v);
95
+ }
96
+
97
+ /*
98
+ * Vector.alloc(size_or_array)
99
+ * Create a new vector from size (zero-filled) or array of values
100
+ */
101
+ static VALUE rb_cvector_s_alloc(VALUE klass, VALUE arg)
102
+ {
103
+ CVector *v;
104
+ VALUE result;
105
+
106
+ if (RB_TYPE_P(arg, T_ARRAY)) {
107
+ long len = RARRAY_LEN(arg);
108
+ v = cvector_alloc((size_t)len);
109
+ for (long i = 0; i < len; i++) {
110
+ v->data[i] = NUM2DBL(rb_ary_entry(arg, i));
111
+ }
112
+ } else {
113
+ size_t size = NUM2SIZET(arg);
114
+ v = cvector_alloc(size);
115
+ }
116
+
117
+ result = TypedData_Wrap_Struct(klass, &cvector_type, v);
118
+ return result;
119
+ }
120
+
121
+ /* Vector#size */
122
+ static VALUE rb_cvector_size(VALUE self)
123
+ {
124
+ CVector *v;
125
+ GET_CVECTOR(self, v);
126
+ return SIZET2NUM(v->size);
127
+ }
128
+
129
+ /* Vector#[] */
130
+ static VALUE rb_cvector_aref(VALUE self, VALUE idx)
131
+ {
132
+ CVector *v;
133
+ GET_CVECTOR(self, v);
134
+ long i = NUM2LONG(idx);
135
+
136
+ if (i < 0) i += v->size;
137
+ if (i < 0 || (size_t)i >= v->size) {
138
+ rb_raise(rb_eIndexError, "index %ld out of bounds", i);
139
+ }
140
+
141
+ return DBL2NUM(v->data[i]);
142
+ }
143
+
144
+ /* Vector#[]= */
145
+ static VALUE rb_cvector_aset(VALUE self, VALUE idx, VALUE val)
146
+ {
147
+ CVector *v;
148
+ GET_CVECTOR(self, v);
149
+ long i = NUM2LONG(idx);
150
+
151
+ if (i < 0) i += v->size;
152
+ if (i < 0 || (size_t)i >= v->size) {
153
+ rb_raise(rb_eIndexError, "index %ld out of bounds", i);
154
+ }
155
+
156
+ v->data[i] = NUM2DBL(val);
157
+ return val;
158
+ }
159
+
160
+ /* Vector#to_a */
161
+ static VALUE rb_cvector_to_a(VALUE self)
162
+ {
163
+ CVector *v;
164
+ GET_CVECTOR(self, v);
165
+ VALUE ary = rb_ary_new_capa((long)v->size);
166
+
167
+ for (size_t i = 0; i < v->size; i++) {
168
+ rb_ary_push(ary, DBL2NUM(v->data[i]));
169
+ }
170
+ return ary;
171
+ }
172
+
173
+ /* Vector#sum */
174
+ static VALUE rb_cvector_sum(VALUE self)
175
+ {
176
+ CVector *v;
177
+ GET_CVECTOR(self, v);
178
+ return DBL2NUM(cvector_sum(v));
179
+ }
180
+
181
+ /* Vector#each */
182
+ static VALUE rb_cvector_each(VALUE self)
183
+ {
184
+ CVector *v;
185
+ GET_CVECTOR(self, v);
186
+
187
+ RETURN_ENUMERATOR(self, 0, 0);
188
+
189
+ for (size_t i = 0; i < v->size; i++) {
190
+ rb_yield(DBL2NUM(v->data[i]));
191
+ }
192
+ return self;
193
+ }
194
+
195
+ /* Vector#collect (map) */
196
+ static VALUE rb_cvector_collect(VALUE self)
197
+ {
198
+ CVector *v;
199
+ GET_CVECTOR(self, v);
200
+
201
+ RETURN_ENUMERATOR(self, 0, 0);
202
+
203
+ CVector *result = cvector_alloc(v->size);
204
+ result->is_col = v->is_col;
205
+
206
+ for (size_t i = 0; i < v->size; i++) {
207
+ VALUE val = rb_yield(DBL2NUM(v->data[i]));
208
+ result->data[i] = NUM2DBL(val);
209
+ }
210
+
211
+ return TypedData_Wrap_Struct(cClassifierVector, &cvector_type, result);
212
+ }
213
+
214
+ /* Vector#normalize */
215
+ static VALUE rb_cvector_normalize(VALUE self)
216
+ {
217
+ CVector *v;
218
+ GET_CVECTOR(self, v);
219
+ CVector *result = cvector_normalize(v);
220
+ return TypedData_Wrap_Struct(cClassifierVector, &cvector_type, result);
221
+ }
222
+
223
+ /* Vector#row - return self as row vector */
224
+ static VALUE rb_cvector_row(VALUE self)
225
+ {
226
+ CVector *v;
227
+ GET_CVECTOR(self, v);
228
+
229
+ CVector *result = cvector_alloc(v->size);
230
+ memcpy(result->data, v->data, v->size * sizeof(double));
231
+ result->is_col = 0;
232
+
233
+ return TypedData_Wrap_Struct(cClassifierVector, &cvector_type, result);
234
+ }
235
+
236
+ /* Vector#col - return self as column vector */
237
+ static VALUE rb_cvector_col(VALUE self)
238
+ {
239
+ CVector *v;
240
+ GET_CVECTOR(self, v);
241
+
242
+ CVector *result = cvector_alloc(v->size);
243
+ memcpy(result->data, v->data, v->size * sizeof(double));
244
+ result->is_col = 1;
245
+
246
+ return TypedData_Wrap_Struct(cClassifierVector, &cvector_type, result);
247
+ }
248
+
249
+ /* Vector#* - dot product with vector, or matrix multiplication */
250
+ static VALUE rb_cvector_mul(VALUE self, VALUE other)
251
+ {
252
+ CVector *v;
253
+ GET_CVECTOR(self, v);
254
+
255
+ if (rb_obj_is_kind_of(other, cClassifierVector)) {
256
+ CVector *w;
257
+ GET_CVECTOR(other, w);
258
+ return DBL2NUM(cvector_dot(v, w));
259
+ } else if (RB_TYPE_P(other, T_FLOAT) || RB_TYPE_P(other, T_FIXNUM)) {
260
+ /* Scalar multiplication */
261
+ double scalar = NUM2DBL(other);
262
+ CVector *result = cvector_alloc(v->size);
263
+ result->is_col = v->is_col;
264
+ for (size_t i = 0; i < v->size; i++) {
265
+ result->data[i] = v->data[i] * scalar;
266
+ }
267
+ return TypedData_Wrap_Struct(cClassifierVector, &cvector_type, result);
268
+ }
269
+
270
+ rb_raise(rb_eTypeError, "Cannot multiply Vector with %s", rb_obj_classname(other));
271
+ return Qnil;
272
+ }
273
+
274
+ /* Vector#_dump for Marshal */
275
+ static VALUE rb_cvector_dump(VALUE self, VALUE depth)
276
+ {
277
+ CVector *v;
278
+ GET_CVECTOR(self, v);
279
+ VALUE ary = rb_cvector_to_a(self);
280
+ rb_ary_push(ary, v->is_col ? Qtrue : Qfalse);
281
+ return rb_marshal_dump(ary, Qnil);
282
+ }
283
+
284
+ /* Vector._load for Marshal */
285
+ static VALUE rb_cvector_s_load(VALUE klass, VALUE str)
286
+ {
287
+ VALUE ary = rb_marshal_load(str);
288
+ VALUE is_col = rb_ary_pop(ary);
289
+ VALUE result = rb_cvector_s_alloc(klass, ary);
290
+ CVector *v;
291
+ GET_CVECTOR(result, v);
292
+ v->is_col = RTEST(is_col) ? 1 : 0;
293
+ return result;
294
+ }
295
+
296
+ void Init_vector(void)
297
+ {
298
+ cClassifierVector = rb_define_class_under(mClassifierLinalg, "Vector", rb_cObject);
299
+
300
+ rb_define_alloc_func(cClassifierVector, rb_cvector_alloc);
301
+ rb_define_singleton_method(cClassifierVector, "alloc", rb_cvector_s_alloc, 1);
302
+ rb_define_singleton_method(cClassifierVector, "_load", rb_cvector_s_load, 1);
303
+
304
+ rb_define_method(cClassifierVector, "size", rb_cvector_size, 0);
305
+ rb_define_method(cClassifierVector, "[]", rb_cvector_aref, 1);
306
+ rb_define_method(cClassifierVector, "[]=", rb_cvector_aset, 2);
307
+ rb_define_method(cClassifierVector, "to_a", rb_cvector_to_a, 0);
308
+ rb_define_method(cClassifierVector, "sum", rb_cvector_sum, 0);
309
+ rb_define_method(cClassifierVector, "each", rb_cvector_each, 0);
310
+ rb_define_method(cClassifierVector, "collect", rb_cvector_collect, 0);
311
+ rb_define_alias(cClassifierVector, "map", "collect");
312
+ rb_define_method(cClassifierVector, "normalize", rb_cvector_normalize, 0);
313
+ rb_define_method(cClassifierVector, "row", rb_cvector_row, 0);
314
+ rb_define_method(cClassifierVector, "col", rb_cvector_col, 0);
315
+ rb_define_method(cClassifierVector, "*", rb_cvector_mul, 1);
316
+ rb_define_method(cClassifierVector, "_dump", rb_cvector_dump, 1);
317
+
318
+ rb_include_module(cClassifierVector, rb_mEnumerable);
319
+ }
@@ -4,23 +4,39 @@
4
4
  # Copyright:: Copyright (c) 2005 Lucas Carlson
5
5
  # License:: LGPL
6
6
 
7
+ require 'json'
8
+ require 'mutex_m'
9
+
7
10
  module Classifier
8
11
  class Bayes
12
+ include Mutex_m
13
+
9
14
  # @rbs @categories: Hash[Symbol, Hash[Symbol, Integer]]
10
15
  # @rbs @total_words: Integer
11
16
  # @rbs @category_counts: Hash[Symbol, Integer]
12
17
  # @rbs @category_word_count: Hash[Symbol, Integer]
18
+ # @rbs @cached_training_count: Float?
19
+ # @rbs @cached_vocab_size: Integer?
20
+ # @rbs @dirty: bool
21
+ # @rbs @storage: Storage::Base?
22
+
23
+ attr_accessor :storage
13
24
 
14
25
# Builds a classifier with zero or more categories. Each category name is
# normalized with prepare_category_name and starts with an empty word table.
# E.g.,
#      b = Classifier::Bayes.new 'Interesting', 'Uninteresting', 'Spam'
# @rbs (*String | Symbol) -> void
def initialize(*categories)
  super()
  @categories = {}
  categories.each do |name|
    @categories[name.prepare_category_name] = {}
  end
  @total_words = 0
  @category_counts = Hash.new(0)
  @category_word_count = Hash.new(0)
  # Caches start empty and no storage is attached.
  @cached_training_count = nil
  @cached_vocab_size = nil
  @dirty = false
  @storage = nil
end
25
41
 
26
42
  # Provides a general training method for all categories specified in Bayes#new
@@ -33,12 +49,17 @@ module Classifier
33
49
  # @rbs (String | Symbol, String) -> void
34
50
  def train(category, text)
35
51
  category = category.prepare_category_name
36
- @category_counts[category] += 1
37
- text.word_hash.each do |word, count|
38
- @categories[category][word] ||= 0
39
- @categories[category][word] += count
40
- @total_words += count
41
- @category_word_count[category] += count
52
+ word_hash = text.word_hash
53
+ synchronize do
54
+ invalidate_caches
55
+ @dirty = true
56
+ @category_counts[category] += 1
57
+ word_hash.each do |word, count|
58
+ @categories[category][word] ||= 0
59
+ @categories[category][word] += count
60
+ @total_words += count
61
+ @category_word_count[category] += count
62
+ end
42
63
  end
43
64
  end
44
65
 
@@ -53,19 +74,24 @@ module Classifier
53
74
  # @rbs (String | Symbol, String) -> void
54
75
  def untrain(category, text)
55
76
  category = category.prepare_category_name
56
- @category_counts[category] -= 1
57
- text.word_hash.each do |word, count|
58
- next unless @total_words >= 0
59
-
60
- orig = @categories[category][word] || 0
61
- @categories[category][word] ||= 0
62
- @categories[category][word] -= count
63
- if @categories[category][word] <= 0
64
- @categories[category].delete(word)
65
- count = orig
77
+ word_hash = text.word_hash
78
+ synchronize do
79
+ invalidate_caches
80
+ @dirty = true
81
+ @category_counts[category] -= 1
82
+ word_hash.each do |word, count|
83
+ next unless @total_words >= 0
84
+
85
+ orig = @categories[category][word] || 0
86
+ @categories[category][word] ||= 0
87
+ @categories[category][word] -= count
88
+ if @categories[category][word] <= 0
89
+ @categories[category].delete(word)
90
+ count = orig
91
+ end
92
+ @category_word_count[category] -= count if @category_word_count[category] >= count
93
+ @total_words -= count
66
94
  end
67
- @category_word_count[category] -= count if @category_word_count[category] >= count
68
- @total_words -= count
69
95
  end
70
96
  end
71
97
 
@@ -77,17 +103,19 @@ module Classifier
77
103
  # @rbs (String) -> Hash[String, Float]
78
104
  def classifications(text)
79
105
  words = text.word_hash.keys
80
- training_count = @category_counts.values.sum.to_f
81
- vocab_size = [@categories.values.flat_map(&:keys).uniq.size, 1].max
106
+ synchronize do
107
+ training_count = cached_training_count
108
+ vocab_size = cached_vocab_size
82
109
 
83
- @categories.to_h do |category, category_words|
84
- smoothed_total = ((@category_word_count[category] || 0) + vocab_size).to_f
110
+ @categories.to_h do |category, category_words|
111
+ smoothed_total = ((@category_word_count[category] || 0) + vocab_size).to_f
85
112
 
86
- # Laplace smoothing: P(word|category) = (count + α) / (total + α * V)
87
- word_score = words.sum { |w| Math.log(((category_words[w] || 0) + 1) / smoothed_total) }
88
- prior_score = Math.log((@category_counts[category] || 0.1) / training_count)
113
+ # Laplace smoothing: P(word|category) = (count + α) / (total + α * V)
114
+ word_score = words.sum { |w| Math.log(((category_words[w] || 0) + 1) / smoothed_total) }
115
+ prior_score = Math.log((@category_counts[category] || 0.1) / training_count)
89
116
 
90
- [category.to_s, word_score + prior_score]
117
+ [category.to_s, word_score + prior_score]
118
+ end
91
119
  end
92
120
  end
93
121
 
@@ -104,6 +132,119 @@ module Classifier
104
132
  best.first.to_s
105
133
  end
106
134
 
135
# Builds a plain Hash snapshot of the classifier state. Top-level keys are
# symbols; category names and words are converted to strings.
#
# @rbs () -> untyped
def as_json(*)
  stringify = ->(hash) { hash.transform_keys(&:to_s) }
  word_tables = stringify.call(@categories).transform_values { |words| stringify.call(words) }

  {
    version: 1,
    type: 'bayes',
    categories: word_tables,
    total_words: @total_words,
    category_counts: stringify.call(@category_counts),
    category_word_count: stringify.call(@category_word_count)
  }
end
149
+
150
# Returns the classifier state (see #as_json) encoded as a JSON string.
# The string can be written to a file and read back with Bayes.from_json.
#
# @rbs () -> String
def to_json(*)
  state = as_json
  state.to_json
end
157
+
158
# Loads a classifier from a JSON string or a Hash created by #to_json or #as_json.
#
# A Hash argument may use symbol or string top-level keys: #as_json returns
# symbol keys, while a parsed JSON document has string keys. Both are
# normalized to strings before the type check and the state restore.
#
# Raises ArgumentError when the "type" entry is not 'bayes'.
#
# @rbs (String | Hash[String, untyped]) -> Bayes
def self.from_json(json)
  data = json.is_a?(String) ? JSON.parse(json) : json.transform_keys(&:to_s)
  raise ArgumentError, "Invalid classifier type: #{data['type']}" unless data['type'] == 'bayes'

  # allocate skips #initialize; restore_state sets every instance variable.
  instance = allocate
  instance.send(:restore_state, data)
  instance
end
169
+
170
# Writes the JSON form of the classifier to the configured storage and
# clears the dirty flag.
# Raises ArgumentError if no storage is configured.
#
# @rbs () -> void
def save
  backend = storage
  raise ArgumentError, 'No storage configured. Use save_to_file(path) or set storage=' unless backend

  backend.write(to_json)
  @dirty = false
end
180
+
181
# Writes the JSON form of the classifier to the file at +path+ (legacy API),
# clears the dirty flag, and returns what File.write returned.
#
# @rbs (String) -> Integer
def save_to_file(path)
  File.write(path, to_json).tap { @dirty = false }
end
189
+
190
# Reloads the classifier from the configured storage, refusing to do so
# when there are unsaved changes.
# Raises ArgumentError if no storage is configured, UnsavedChangesError if
# the classifier is dirty, and StorageError if the storage has no data.
# Use reload! to discard unsaved changes instead.
#
# @rbs () -> self
def reload
  raise ArgumentError, 'No storage configured' unless storage
  raise UnsavedChangesError, 'Unsaved changes would be lost. Call save first or use reload!' if @dirty

  # From here on the steps are the same as a forced reload.
  reload!
end
206
+
207
# Reloads the classifier from storage without checking for unsaved changes.
# Raises ArgumentError if no storage is configured and StorageError if the
# storage has no data.
#
# @rbs () -> self
def reload!
  backend = storage
  raise ArgumentError, 'No storage configured' unless backend

  payload = backend.read
  raise StorageError, 'No saved state found' unless payload

  restore_from_json(payload)
  @dirty = false
  self
end
220
+
221
# Reports the dirty flag: set by the mutating methods (train, untrain,
# add_category, remove_category) and cleared by save, save_to_file and reload.
#
# @rbs () -> bool
def dirty?
  @dirty
end
227
+
228
# Reads saved state from +storage+, builds a classifier from it, and
# attaches +storage+ to the returned instance.
# Raises StorageError when the storage returns no data.
#
# @rbs (storage: Storage::Base) -> Bayes
def self.load(storage:)
  payload = storage.read
  raise StorageError, 'No saved state found' unless payload

  from_json(payload).tap { |classifier| classifier.storage = storage }
end
240
+
241
# Reads the file at +path+ and builds a classifier from its JSON contents
# (legacy API).
#
# @rbs (String) -> Bayes
def self.load_from_file(path)
  contents = File.read(path)
  from_json(contents)
end
247
+
107
248
  #
108
249
  # Provides training and untraining methods for the categories specified in Bayes#new
109
250
  # For example:
@@ -134,7 +275,7 @@ module Classifier
134
275
  #
135
276
  # @rbs () -> Array[String]
136
277
  def categories
137
- @categories.keys.collect(&:to_s)
278
+ synchronize { @categories.keys.collect(&:to_s) }
138
279
  end
139
280
 
140
281
  # Allows you to add categories to the classifier.
@@ -148,11 +289,31 @@ module Classifier
148
289
  #
149
290
  # @rbs (String | Symbol) -> Hash[Symbol, Integer]
150
291
  def add_category(category)
151
- @categories[category.prepare_category_name] = {}
292
+ synchronize do
293
+ invalidate_caches
294
+ @dirty = true
295
+ @categories[category.prepare_category_name] = {}
296
+ end
152
297
  end
153
298
 
154
299
  alias append_category add_category
155
300
 
301
# Custom Marshal serialization. Only the training data and the dirty flag
# are dumped; the mutex, the caches and the storage are left out.
# @rbs () -> Array[untyped]
def marshal_dump
  [
    @categories,
    @total_words,
    @category_counts,
    @category_word_count,
    @dirty
  ]
end
+
307
# Custom Marshal deserialization. Re-creates the mutex, restores the five
# values written by marshal_dump (in the same order), and resets the caches
# and the storage to nil.
# @rbs (Array[untyped]) -> void
def marshal_load(data)
  mu_initialize

  categories, total_words, category_counts, category_word_count, dirty = data
  @categories = categories
  @total_words = total_words
  @category_counts = category_counts
  @category_word_count = category_word_count
  @dirty = dirty

  @cached_training_count = nil
  @cached_vocab_size = nil
  @storage = nil
end
316
+
156
317
  # Allows you to remove categories from the classifier.
157
318
  # For example:
158
319
  # b.remove_category "Spam"
@@ -164,13 +325,72 @@ module Classifier
164
325
  # @rbs (String | Symbol) -> void
165
326
  def remove_category(category)
166
327
  category = category.prepare_category_name
167
- raise StandardError, "No such category: #{category}" unless @categories.key?(category)
328
+ synchronize do
329
+ raise StandardError, "No such category: #{category}" unless @categories.key?(category)
330
+
331
+ invalidate_caches
332
+ @dirty = true
333
+ @total_words -= @category_word_count[category].to_i
334
+
335
+ @categories.delete(category)
336
+ @category_counts.delete(category)
337
+ @category_word_count.delete(category)
338
+ end
339
+ end
340
+
341
+ private
342
+
343
# Restores classifier state from a JSON string (used by reload and reload!).
#
# The configured storage is kept across the restore. restore_state resets
# @storage to nil, which would otherwise leave a reloaded classifier with
# no storage, so a later save or reload would raise ArgumentError.
#
# Raises ArgumentError when the "type" entry is not 'bayes'.
#
# NOTE(review): restore_state also calls mu_initialize, so the mutex looks
# to be replaced while this block still holds the previous one — confirm
# that is intended.
#
# @rbs (String) -> void
def restore_from_json(json)
  data = JSON.parse(json)
  raise ArgumentError, "Invalid classifier type: #{data['type']}" unless data['type'] == 'bayes'

  synchronize do
    current_storage = @storage
    begin
      restore_state(data)
    ensure
      @storage = current_storage
    end
  end
end
353
+
354
+ # Restores classifier state from a hash (used by from_json)
355
+ # @rbs (Hash[String, untyped]) -> void
356
+ def restore_state(data)
357
+ mu_initialize
358
+ @categories = {} #: Hash[Symbol, Hash[Symbol, Integer]]
359
+ @total_words = data['total_words']
360
+ @category_counts = Hash.new(0) #: Hash[Symbol, Integer]
361
+ @category_word_count = Hash.new(0) #: Hash[Symbol, Integer]
362
+ @cached_training_count = nil
363
+ @cached_vocab_size = nil
364
+ @dirty = false
365
+ @storage = nil
366
+
367
+ data['categories'].each do |cat_name, words|
368
+ @categories[cat_name.to_sym] = words.transform_keys(&:to_sym)
369
+ end
168
370
 
169
- @total_words -= @category_word_count[category].to_i
371
+ data['category_counts'].each do |cat_name, count|
372
+ @category_counts[cat_name.to_sym] = count
373
+ end
374
+
375
+ data['category_word_count'].each do |cat_name, count|
376
+ @category_word_count[cat_name.to_sym] = count
377
+ end
378
+ end
379
+
380
# Clears both memoized values so they are recomputed on next use.
# @rbs () -> void
def invalidate_caches
  @cached_training_count = @cached_vocab_size = nil
end
385
+
386
# Sum of all per-category training counts as a Float, memoized until
# invalidate_caches is called.
# @rbs () -> Float
def cached_training_count
  @cached_training_count ||= @category_counts.sum { |_category, count| count }.to_f
end
170
390
 
171
- @categories.delete(category)
172
- @category_counts.delete(category)
173
- @category_word_count.delete(category)
391
# Number of distinct words across all categories, never less than 1,
# memoized until invalidate_caches is called.
# @rbs () -> Integer
def cached_vocab_size
  @cached_vocab_size ||= begin
    vocabulary = @categories.each_value.reduce([]) { |seen, words| seen | words.keys }
    [vocabulary.size, 1].max
  end
end
175
395
  end
176
396
  end
@@ -0,0 +1,16 @@
1
+ # rbs_inline: enabled
2
+
3
+ # Author:: Lucas Carlson (mailto:lucas@rufy.com)
4
+ # Copyright:: Copyright (c) 2005 Lucas Carlson
5
+ # License:: LGPL
6
+
7
module Classifier
  # Root of the Classifier error hierarchy; rescue this to catch any of them.
  class Error < StandardError
  end

  # Signals that a reload was refused because unsaved changes would be lost.
  class UnsavedChangesError < Error
  end

  # Signals a failed storage operation, such as finding no saved state.
  class StorageError < Error
  end
end