classifier 2.0.0 → 2.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CLAUDE.md +23 -13
- data/README.md +82 -67
- data/ext/classifier/classifier_ext.c +25 -0
- data/ext/classifier/extconf.rb +15 -0
- data/ext/classifier/linalg.h +64 -0
- data/ext/classifier/matrix.c +387 -0
- data/ext/classifier/svd.c +208 -0
- data/ext/classifier/vector.c +319 -0
- data/lib/classifier/bayes.rb +253 -33
- data/lib/classifier/errors.rb +16 -0
- data/lib/classifier/extensions/vector.rb +12 -4
- data/lib/classifier/lsi/content_node.rb +5 -5
- data/lib/classifier/lsi.rb +439 -141
- data/lib/classifier/storage/base.rb +50 -0
- data/lib/classifier/storage/file.rb +51 -0
- data/lib/classifier/storage/memory.rb +49 -0
- data/lib/classifier/storage.rb +9 -0
- data/lib/classifier.rb +2 -0
- data/sig/vendor/json.rbs +4 -0
- data/sig/vendor/mutex_m.rbs +16 -0
- data/test/test_helper.rb +2 -0
- metadata +36 -5
- data/lib/classifier/extensions/vector_serialize.rb +0 -18
|
@@ -0,0 +1,319 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* vector.c
|
|
3
|
+
* Vector implementation for Classifier native linear algebra
|
|
4
|
+
*/
|
|
5
|
+
|
|
6
|
+
#include "linalg.h"
|
|
7
|
+
|
|
8
|
+
const rb_data_type_t cvector_type = {
|
|
9
|
+
.wrap_struct_name = "Classifier::Linalg::Vector",
|
|
10
|
+
.function = {
|
|
11
|
+
.dmark = NULL,
|
|
12
|
+
.dfree = cvector_free,
|
|
13
|
+
.dsize = NULL,
|
|
14
|
+
},
|
|
15
|
+
.flags = RUBY_TYPED_FREE_IMMEDIATELY
|
|
16
|
+
};
|
|
17
|
+
|
|
18
|
+
/* Allocate a new CVector */
|
|
19
|
+
CVector *cvector_alloc(size_t size)
|
|
20
|
+
{
|
|
21
|
+
CVector *v = ALLOC(CVector);
|
|
22
|
+
v->size = size;
|
|
23
|
+
v->data = ALLOC_N(double, size);
|
|
24
|
+
v->is_col = 0; /* Default to row vector */
|
|
25
|
+
memset(v->data, 0, size * sizeof(double));
|
|
26
|
+
return v;
|
|
27
|
+
}
|
|
28
|
+
|
|
29
|
+
/* Free a CVector */
|
|
30
|
+
void cvector_free(void *ptr)
|
|
31
|
+
{
|
|
32
|
+
CVector *v = (CVector *)ptr;
|
|
33
|
+
if (v) {
|
|
34
|
+
if (v->data) xfree(v->data);
|
|
35
|
+
xfree(v);
|
|
36
|
+
}
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
/* Calculate magnitude (Euclidean norm) */
|
|
40
|
+
double cvector_magnitude(CVector *v)
|
|
41
|
+
{
|
|
42
|
+
double sum = 0.0;
|
|
43
|
+
for (size_t i = 0; i < v->size; i++) {
|
|
44
|
+
sum += v->data[i] * v->data[i];
|
|
45
|
+
}
|
|
46
|
+
return sqrt(sum);
|
|
47
|
+
}
|
|
48
|
+
|
|
49
|
+
/* Return normalized copy */
|
|
50
|
+
CVector *cvector_normalize(CVector *v)
|
|
51
|
+
{
|
|
52
|
+
CVector *result = cvector_alloc(v->size);
|
|
53
|
+
result->is_col = v->is_col;
|
|
54
|
+
double mag = cvector_magnitude(v);
|
|
55
|
+
|
|
56
|
+
if (mag <= CLASSIFIER_EPSILON) {
|
|
57
|
+
/* Return zero vector if magnitude is too small */
|
|
58
|
+
return result;
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
for (size_t i = 0; i < v->size; i++) {
|
|
62
|
+
result->data[i] = v->data[i] / mag;
|
|
63
|
+
}
|
|
64
|
+
return result;
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
/* Sum all elements */
|
|
68
|
+
double cvector_sum(CVector *v)
|
|
69
|
+
{
|
|
70
|
+
double sum = 0.0;
|
|
71
|
+
for (size_t i = 0; i < v->size; i++) {
|
|
72
|
+
sum += v->data[i];
|
|
73
|
+
}
|
|
74
|
+
return sum;
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
/* Dot product */
|
|
78
|
+
double cvector_dot(CVector *a, CVector *b)
|
|
79
|
+
{
|
|
80
|
+
if (a->size != b->size) {
|
|
81
|
+
rb_raise(rb_eArgError, "Vector sizes must match for dot product");
|
|
82
|
+
}
|
|
83
|
+
double sum = 0.0;
|
|
84
|
+
for (size_t i = 0; i < a->size; i++) {
|
|
85
|
+
sum += a->data[i] * b->data[i];
|
|
86
|
+
}
|
|
87
|
+
return sum;
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
/* Ruby allocation function */
|
|
91
|
+
static VALUE rb_cvector_alloc(VALUE klass)
|
|
92
|
+
{
|
|
93
|
+
CVector *v = cvector_alloc(0);
|
|
94
|
+
return TypedData_Wrap_Struct(klass, &cvector_type, v);
|
|
95
|
+
}
|
|
96
|
+
|
|
97
|
+
/*
|
|
98
|
+
* Vector.alloc(size_or_array)
|
|
99
|
+
* Create a new vector from size (zero-filled) or array of values
|
|
100
|
+
*/
|
|
101
|
+
static VALUE rb_cvector_s_alloc(VALUE klass, VALUE arg)
|
|
102
|
+
{
|
|
103
|
+
CVector *v;
|
|
104
|
+
VALUE result;
|
|
105
|
+
|
|
106
|
+
if (RB_TYPE_P(arg, T_ARRAY)) {
|
|
107
|
+
long len = RARRAY_LEN(arg);
|
|
108
|
+
v = cvector_alloc((size_t)len);
|
|
109
|
+
for (long i = 0; i < len; i++) {
|
|
110
|
+
v->data[i] = NUM2DBL(rb_ary_entry(arg, i));
|
|
111
|
+
}
|
|
112
|
+
} else {
|
|
113
|
+
size_t size = NUM2SIZET(arg);
|
|
114
|
+
v = cvector_alloc(size);
|
|
115
|
+
}
|
|
116
|
+
|
|
117
|
+
result = TypedData_Wrap_Struct(klass, &cvector_type, v);
|
|
118
|
+
return result;
|
|
119
|
+
}
|
|
120
|
+
|
|
121
|
+
/* Vector#size */
|
|
122
|
+
static VALUE rb_cvector_size(VALUE self)
|
|
123
|
+
{
|
|
124
|
+
CVector *v;
|
|
125
|
+
GET_CVECTOR(self, v);
|
|
126
|
+
return SIZET2NUM(v->size);
|
|
127
|
+
}
|
|
128
|
+
|
|
129
|
+
/* Vector#[] */
|
|
130
|
+
static VALUE rb_cvector_aref(VALUE self, VALUE idx)
|
|
131
|
+
{
|
|
132
|
+
CVector *v;
|
|
133
|
+
GET_CVECTOR(self, v);
|
|
134
|
+
long i = NUM2LONG(idx);
|
|
135
|
+
|
|
136
|
+
if (i < 0) i += v->size;
|
|
137
|
+
if (i < 0 || (size_t)i >= v->size) {
|
|
138
|
+
rb_raise(rb_eIndexError, "index %ld out of bounds", i);
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
return DBL2NUM(v->data[i]);
|
|
142
|
+
}
|
|
143
|
+
|
|
144
|
+
/* Vector#[]= */
|
|
145
|
+
static VALUE rb_cvector_aset(VALUE self, VALUE idx, VALUE val)
|
|
146
|
+
{
|
|
147
|
+
CVector *v;
|
|
148
|
+
GET_CVECTOR(self, v);
|
|
149
|
+
long i = NUM2LONG(idx);
|
|
150
|
+
|
|
151
|
+
if (i < 0) i += v->size;
|
|
152
|
+
if (i < 0 || (size_t)i >= v->size) {
|
|
153
|
+
rb_raise(rb_eIndexError, "index %ld out of bounds", i);
|
|
154
|
+
}
|
|
155
|
+
|
|
156
|
+
v->data[i] = NUM2DBL(val);
|
|
157
|
+
return val;
|
|
158
|
+
}
|
|
159
|
+
|
|
160
|
+
/* Vector#to_a */
|
|
161
|
+
static VALUE rb_cvector_to_a(VALUE self)
|
|
162
|
+
{
|
|
163
|
+
CVector *v;
|
|
164
|
+
GET_CVECTOR(self, v);
|
|
165
|
+
VALUE ary = rb_ary_new_capa((long)v->size);
|
|
166
|
+
|
|
167
|
+
for (size_t i = 0; i < v->size; i++) {
|
|
168
|
+
rb_ary_push(ary, DBL2NUM(v->data[i]));
|
|
169
|
+
}
|
|
170
|
+
return ary;
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
/* Vector#sum */
|
|
174
|
+
static VALUE rb_cvector_sum(VALUE self)
|
|
175
|
+
{
|
|
176
|
+
CVector *v;
|
|
177
|
+
GET_CVECTOR(self, v);
|
|
178
|
+
return DBL2NUM(cvector_sum(v));
|
|
179
|
+
}
|
|
180
|
+
|
|
181
|
+
/* Vector#each */
|
|
182
|
+
static VALUE rb_cvector_each(VALUE self)
|
|
183
|
+
{
|
|
184
|
+
CVector *v;
|
|
185
|
+
GET_CVECTOR(self, v);
|
|
186
|
+
|
|
187
|
+
RETURN_ENUMERATOR(self, 0, 0);
|
|
188
|
+
|
|
189
|
+
for (size_t i = 0; i < v->size; i++) {
|
|
190
|
+
rb_yield(DBL2NUM(v->data[i]));
|
|
191
|
+
}
|
|
192
|
+
return self;
|
|
193
|
+
}
|
|
194
|
+
|
|
195
|
+
/* Vector#collect (map) */
|
|
196
|
+
static VALUE rb_cvector_collect(VALUE self)
|
|
197
|
+
{
|
|
198
|
+
CVector *v;
|
|
199
|
+
GET_CVECTOR(self, v);
|
|
200
|
+
|
|
201
|
+
RETURN_ENUMERATOR(self, 0, 0);
|
|
202
|
+
|
|
203
|
+
CVector *result = cvector_alloc(v->size);
|
|
204
|
+
result->is_col = v->is_col;
|
|
205
|
+
|
|
206
|
+
for (size_t i = 0; i < v->size; i++) {
|
|
207
|
+
VALUE val = rb_yield(DBL2NUM(v->data[i]));
|
|
208
|
+
result->data[i] = NUM2DBL(val);
|
|
209
|
+
}
|
|
210
|
+
|
|
211
|
+
return TypedData_Wrap_Struct(cClassifierVector, &cvector_type, result);
|
|
212
|
+
}
|
|
213
|
+
|
|
214
|
+
/* Vector#normalize */
|
|
215
|
+
static VALUE rb_cvector_normalize(VALUE self)
|
|
216
|
+
{
|
|
217
|
+
CVector *v;
|
|
218
|
+
GET_CVECTOR(self, v);
|
|
219
|
+
CVector *result = cvector_normalize(v);
|
|
220
|
+
return TypedData_Wrap_Struct(cClassifierVector, &cvector_type, result);
|
|
221
|
+
}
|
|
222
|
+
|
|
223
|
+
/* Vector#row - return self as row vector */
|
|
224
|
+
static VALUE rb_cvector_row(VALUE self)
|
|
225
|
+
{
|
|
226
|
+
CVector *v;
|
|
227
|
+
GET_CVECTOR(self, v);
|
|
228
|
+
|
|
229
|
+
CVector *result = cvector_alloc(v->size);
|
|
230
|
+
memcpy(result->data, v->data, v->size * sizeof(double));
|
|
231
|
+
result->is_col = 0;
|
|
232
|
+
|
|
233
|
+
return TypedData_Wrap_Struct(cClassifierVector, &cvector_type, result);
|
|
234
|
+
}
|
|
235
|
+
|
|
236
|
+
/* Vector#col - return self as column vector */
|
|
237
|
+
static VALUE rb_cvector_col(VALUE self)
|
|
238
|
+
{
|
|
239
|
+
CVector *v;
|
|
240
|
+
GET_CVECTOR(self, v);
|
|
241
|
+
|
|
242
|
+
CVector *result = cvector_alloc(v->size);
|
|
243
|
+
memcpy(result->data, v->data, v->size * sizeof(double));
|
|
244
|
+
result->is_col = 1;
|
|
245
|
+
|
|
246
|
+
return TypedData_Wrap_Struct(cClassifierVector, &cvector_type, result);
|
|
247
|
+
}
|
|
248
|
+
|
|
249
|
+
/* Vector#* - dot product with vector, or matrix multiplication */
|
|
250
|
+
static VALUE rb_cvector_mul(VALUE self, VALUE other)
|
|
251
|
+
{
|
|
252
|
+
CVector *v;
|
|
253
|
+
GET_CVECTOR(self, v);
|
|
254
|
+
|
|
255
|
+
if (rb_obj_is_kind_of(other, cClassifierVector)) {
|
|
256
|
+
CVector *w;
|
|
257
|
+
GET_CVECTOR(other, w);
|
|
258
|
+
return DBL2NUM(cvector_dot(v, w));
|
|
259
|
+
} else if (RB_TYPE_P(other, T_FLOAT) || RB_TYPE_P(other, T_FIXNUM)) {
|
|
260
|
+
/* Scalar multiplication */
|
|
261
|
+
double scalar = NUM2DBL(other);
|
|
262
|
+
CVector *result = cvector_alloc(v->size);
|
|
263
|
+
result->is_col = v->is_col;
|
|
264
|
+
for (size_t i = 0; i < v->size; i++) {
|
|
265
|
+
result->data[i] = v->data[i] * scalar;
|
|
266
|
+
}
|
|
267
|
+
return TypedData_Wrap_Struct(cClassifierVector, &cvector_type, result);
|
|
268
|
+
}
|
|
269
|
+
|
|
270
|
+
rb_raise(rb_eTypeError, "Cannot multiply Vector with %s", rb_obj_classname(other));
|
|
271
|
+
return Qnil;
|
|
272
|
+
}
|
|
273
|
+
|
|
274
|
+
/* Vector#_dump for Marshal */
|
|
275
|
+
static VALUE rb_cvector_dump(VALUE self, VALUE depth)
|
|
276
|
+
{
|
|
277
|
+
CVector *v;
|
|
278
|
+
GET_CVECTOR(self, v);
|
|
279
|
+
VALUE ary = rb_cvector_to_a(self);
|
|
280
|
+
rb_ary_push(ary, v->is_col ? Qtrue : Qfalse);
|
|
281
|
+
return rb_marshal_dump(ary, Qnil);
|
|
282
|
+
}
|
|
283
|
+
|
|
284
|
+
/* Vector._load for Marshal */
|
|
285
|
+
static VALUE rb_cvector_s_load(VALUE klass, VALUE str)
|
|
286
|
+
{
|
|
287
|
+
VALUE ary = rb_marshal_load(str);
|
|
288
|
+
VALUE is_col = rb_ary_pop(ary);
|
|
289
|
+
VALUE result = rb_cvector_s_alloc(klass, ary);
|
|
290
|
+
CVector *v;
|
|
291
|
+
GET_CVECTOR(result, v);
|
|
292
|
+
v->is_col = RTEST(is_col) ? 1 : 0;
|
|
293
|
+
return result;
|
|
294
|
+
}
|
|
295
|
+
|
|
296
|
+
void Init_vector(void)
|
|
297
|
+
{
|
|
298
|
+
cClassifierVector = rb_define_class_under(mClassifierLinalg, "Vector", rb_cObject);
|
|
299
|
+
|
|
300
|
+
rb_define_alloc_func(cClassifierVector, rb_cvector_alloc);
|
|
301
|
+
rb_define_singleton_method(cClassifierVector, "alloc", rb_cvector_s_alloc, 1);
|
|
302
|
+
rb_define_singleton_method(cClassifierVector, "_load", rb_cvector_s_load, 1);
|
|
303
|
+
|
|
304
|
+
rb_define_method(cClassifierVector, "size", rb_cvector_size, 0);
|
|
305
|
+
rb_define_method(cClassifierVector, "[]", rb_cvector_aref, 1);
|
|
306
|
+
rb_define_method(cClassifierVector, "[]=", rb_cvector_aset, 2);
|
|
307
|
+
rb_define_method(cClassifierVector, "to_a", rb_cvector_to_a, 0);
|
|
308
|
+
rb_define_method(cClassifierVector, "sum", rb_cvector_sum, 0);
|
|
309
|
+
rb_define_method(cClassifierVector, "each", rb_cvector_each, 0);
|
|
310
|
+
rb_define_method(cClassifierVector, "collect", rb_cvector_collect, 0);
|
|
311
|
+
rb_define_alias(cClassifierVector, "map", "collect");
|
|
312
|
+
rb_define_method(cClassifierVector, "normalize", rb_cvector_normalize, 0);
|
|
313
|
+
rb_define_method(cClassifierVector, "row", rb_cvector_row, 0);
|
|
314
|
+
rb_define_method(cClassifierVector, "col", rb_cvector_col, 0);
|
|
315
|
+
rb_define_method(cClassifierVector, "*", rb_cvector_mul, 1);
|
|
316
|
+
rb_define_method(cClassifierVector, "_dump", rb_cvector_dump, 1);
|
|
317
|
+
|
|
318
|
+
rb_include_module(cClassifierVector, rb_mEnumerable);
|
|
319
|
+
}
|
data/lib/classifier/bayes.rb
CHANGED
|
@@ -4,23 +4,39 @@
|
|
|
4
4
|
# Copyright:: Copyright (c) 2005 Lucas Carlson
|
|
5
5
|
# License:: LGPL
|
|
6
6
|
|
|
7
|
+
require 'json'
|
|
8
|
+
require 'mutex_m'
|
|
9
|
+
|
|
7
10
|
module Classifier
|
|
8
11
|
class Bayes
|
|
12
|
+
include Mutex_m
|
|
13
|
+
|
|
9
14
|
# @rbs @categories: Hash[Symbol, Hash[Symbol, Integer]]
|
|
10
15
|
# @rbs @total_words: Integer
|
|
11
16
|
# @rbs @category_counts: Hash[Symbol, Integer]
|
|
12
17
|
# @rbs @category_word_count: Hash[Symbol, Integer]
|
|
18
|
+
# @rbs @cached_training_count: Float?
|
|
19
|
+
# @rbs @cached_vocab_size: Integer?
|
|
20
|
+
# @rbs @dirty: bool
|
|
21
|
+
# @rbs @storage: Storage::Base?
|
|
22
|
+
|
|
23
|
+
attr_accessor :storage
|
|
13
24
|
|
|
14
25
|
# The class can be created with one or more categories, each of which will be
# initialized and given a training method. E.g.,
#      b = Classifier::Bayes.new 'Interesting', 'Uninteresting', 'Spam'
#
# Every category starts with an empty word table. All counters start at zero,
# and the instance starts clean (not dirty) with no storage configured.
# @rbs (*String | Symbol) -> void
def initialize(*categories)
  super() # Mutex_m sets up the lock used by #synchronize
  @categories = categories.to_h { |category| [category.prepare_category_name, {}] }
  @total_words = 0
  @category_counts = Hash.new(0)
  @category_word_count = Hash.new(0)
  @cached_training_count = @cached_vocab_size = nil
  @dirty = false
  @storage = nil
end
|
|
25
41
|
|
|
26
42
|
# Provides a general training method for all categories specified in Bayes#new
|
|
@@ -33,12 +49,17 @@ module Classifier
|
|
|
33
49
|
# @rbs (String | Symbol, String) -> void
def train(category, text)
  category = category.prepare_category_name
  word_hash = text.word_hash
  synchronize do
    # Validate before mutating anything. Previously an unknown category first
    # bumped @category_counts and set @dirty, then raised NoMethodError on the
    # nil word table. That left a phantom entry that inflates the training
    # count used for every prior in #classifications.
    raise StandardError, "No such category: #{category}" unless @categories.key?(category)

    invalidate_caches
    @dirty = true
    @category_counts[category] += 1
    category_words = @categories[category]
    word_hash.each do |word, count|
      category_words[word] = (category_words[word] || 0) + count
      @total_words += count
      @category_word_count[category] += count
    end
  end
end
|
|
44
65
|
|
|
@@ -53,19 +74,24 @@ module Classifier
|
|
|
53
74
|
# @rbs (String | Symbol, String) -> void
def untrain(category, text)
  category = category.prepare_category_name
  word_hash = text.word_hash
  synchronize do
    # Validate first: an unknown category used to be partially mutated
    # (@category_counts set to -1, model marked dirty) before crashing.
    raise StandardError, "No such category: #{category}" unless @categories.key?(category)

    invalidate_caches
    @dirty = true
    # Never let the document count go negative. A negative count makes the
    # prior in #classifications take Math.log of a negative number, which
    # raises Math::DomainError.
    @category_counts[category] -= 1 if @category_counts[category].positive?
    category_words = @categories[category]
    word_hash.each do |word, count|
      next unless @total_words >= 0

      orig = category_words[word] || 0
      category_words[word] = orig - count
      if category_words[word] <= 0
        # Word fully removed: only subtract what was actually there.
        category_words.delete(word)
        count = orig
      end
      @category_word_count[category] -= count if @category_word_count[category] >= count
      @total_words -= count
    end
  end
end
|
|
71
97
|
|
|
@@ -77,17 +103,19 @@ module Classifier
|
|
|
77
103
|
# @rbs (String) -> Hash[String, Float]
def classifications(text)
  words = text.word_hash.keys
  synchronize do
    training_count = cached_training_count
    vocab_size = cached_vocab_size

    @categories.to_h do |category, category_words|
      smoothed_total = ((@category_word_count[category] || 0) + vocab_size).to_f

      # Laplace smoothing: P(word|category) = (count + α) / (total + α * V)
      word_score = words.sum { |w| Math.log(((category_words[w] || 0) + 1) / smoothed_total) }

      # @category_counts defaults to 0, so the former `|| 0.1` fallback never
      # fired. An untrained category got log(0) = -Infinity, and a negative
      # count raised Math::DomainError. Use the 0.1 pseudo-count whenever the
      # stored count is not positive.
      doc_count = @category_counts[category].to_f
      doc_count = 0.1 unless doc_count.positive?
      prior_score = Math.log(doc_count / training_count)

      [category.to_s, word_score + prior_score]
    end
  end
end
|
|
93
121
|
|
|
@@ -104,6 +132,119 @@ module Classifier
|
|
|
104
132
|
best.first.to_s
|
|
105
133
|
end
|
|
106
134
|
|
|
135
|
+
# Returns a hash representation of the classifier state.
# This can be converted to JSON or used directly.
#
# All category names and words are converted to String keys. The snapshot
# is built inside #synchronize, so a concurrent #train/#untrain cannot mutate
# the hashes while they are being copied. Without the lock that could raise
# "can't add a new key into hash during iteration". Do not call this from
# inside an existing #synchronize block: Mutex_m's lock is not reentrant.
#
# @rbs () -> untyped
def as_json(*)
  synchronize do
    {
      version: 1,
      type: 'bayes',
      categories: @categories.transform_keys(&:to_s).transform_values { |v| v.transform_keys(&:to_s) },
      total_words: @total_words,
      category_counts: @category_counts.transform_keys(&:to_s),
      category_word_count: @category_word_count.transform_keys(&:to_s)
    }
  end
end
|
|
149
|
+
|
|
150
|
+
# Serializes the classifier state to a JSON string.
# This can be saved to a file and later loaded with Bayes.from_json.
# Any arguments passed in (e.g. a generator state) are ignored.
#
# @rbs () -> String
def to_json(*)
  snapshot = as_json
  snapshot.to_json
end
|
|
157
|
+
|
|
158
|
+
# Loads a classifier from a JSON string or a Hash created by #to_json or #as_json.
# A String is parsed with JSON.parse; anything else is used as-is. Raises
# ArgumentError unless the data's 'type' is 'bayes'. The instance is built
# with #allocate (bypassing #initialize) and filled in by restore_state.
#
# @rbs (String | Hash[String, untyped]) -> Bayes
def self.from_json(json)
  data = json.is_a?(String) ? JSON.parse(json) : json
  raise ArgumentError, "Invalid classifier type: #{data['type']}" if data['type'] != 'bayes'

  allocate.tap { |instance| instance.send(:restore_state, data) }
end
|
|
169
|
+
|
|
170
|
+
# Saves the classifier to the configured storage.
# Raises ArgumentError if no storage is configured. On success the
# classifier is marked clean (dirty? becomes false).
#
# @rbs () -> void
def save
  target = storage
  raise ArgumentError, 'No storage configured. Use save_to_file(path) or set storage=' unless target

  target.write(to_json)
  @dirty = false
end
|
|
180
|
+
|
|
181
|
+
# Saves the classifier state to a file (legacy API).
# Writes #to_json to path, marks the classifier clean, and returns the
# byte count reported by File.write.
#
# @rbs (String) -> Integer
def save_to_file(path)
  File.write(path, to_json).tap { @dirty = false }
end
|
|
189
|
+
|
|
190
|
+
# Reloads the classifier from the configured storage.
# Raises ArgumentError if no storage is configured, and UnsavedChangesError if
# there are unsaved changes. Use reload! to force reload and discard changes.
# After these guards it does exactly what reload! does.
#
# @rbs () -> self
def reload
  raise ArgumentError, 'No storage configured' unless storage
  raise UnsavedChangesError, 'Unsaved changes would be lost. Call save first or use reload!' if @dirty

  reload!
end
|
|
206
|
+
|
|
207
|
+
# Force reloads the classifier from storage, discarding any unsaved changes.
# Raises ArgumentError if no storage is configured, and StorageError if the
# storage returns no data. Returns self, marked clean.
#
# @rbs () -> self
def reload!
  raise ArgumentError, 'No storage configured' unless storage

  json = storage.read
  raise StorageError, 'No saved state found' unless json

  restore_from_json(json)
  tap { @dirty = false }
end
|
|
220
|
+
|
|
221
|
+
# Returns true if there are unsaved changes.
# The flag is set by the mutating methods and cleared by save, save_to_file,
# reload and reload!.
#
# @rbs () -> bool
def dirty?
  return @dirty
end
|
|
227
|
+
|
|
228
|
+
# Loads a classifier from the configured storage.
# The storage is set on the returned instance. Raises StorageError when the
# storage has no saved state.
#
# @rbs (storage: Storage::Base) -> Bayes
def self.load(storage:)
  json = storage.read
  raise StorageError, 'No saved state found' unless json

  from_json(json).tap { |classifier| classifier.storage = storage }
end
|
|
240
|
+
|
|
241
|
+
# Loads a classifier from a file (legacy API).
# Reads the whole file and hands it to Bayes.from_json. No storage is set on
# the returned instance.
#
# @rbs (String) -> Bayes
def self.load_from_file(path)
  contents = File.read(path)
  from_json(contents)
end
|
|
247
|
+
|
|
107
248
|
#
|
|
108
249
|
# Provides training and untraining methods for the categories specified in Bayes#new
|
|
109
250
|
# For example:
|
|
@@ -134,7 +275,7 @@ module Classifier
|
|
|
134
275
|
#
|
|
135
276
|
# @rbs () -> Array[String]
def categories
  # Snapshot the names under the lock, in insertion order, as Strings.
  synchronize do
    @categories.each_key.map(&:to_s)
  end
end
|
|
139
280
|
|
|
140
281
|
# Allows you to add categories to the classifier.
|
|
@@ -148,11 +289,31 @@ module Classifier
|
|
|
148
289
|
#
|
|
149
290
|
# @rbs (String | Symbol) -> Hash[Symbol, Integer]
def add_category(category)
  synchronize do
    invalidate_caches
    @dirty = true
    # NOTE(review): for an already-existing category this replaces its word
    # table with an empty hash but leaves @category_counts,
    # @category_word_count and @total_words untouched. Confirm this is intended.
    fresh_words = {} #: Hash[Symbol, Integer]
    @categories[category.prepare_category_name] = fresh_words
  end
end
|
|
153
298
|
|
|
154
299
|
alias append_category add_category
|
|
155
300
|
|
|
301
|
+
# Custom marshal serialization to exclude mutex state.
# Dumps only the model data plus the dirty flag. The caches and storage are
# not persisted; marshal_load recreates them.
# @rbs () -> Array[untyped]
def marshal_dump
  state = [@categories, @total_words, @category_counts, @category_word_count]
  state << @dirty
end
|
|
306
|
+
|
|
307
|
+
# Custom marshal deserialization to recreate mutex.
# Expects the array produced by marshal_dump. It creates a fresh lock,
# restores the model data and dirty flag, clears the caches, and leaves no
# storage configured.
# @rbs (Array[untyped]) -> void
def marshal_load(data)
  mu_initialize
  @categories, @total_words, @category_counts, @category_word_count, @dirty = data
  invalidate_caches
  @storage = nil
end
|
|
316
|
+
|
|
156
317
|
# Allows you to remove categories from the classifier.
|
|
157
318
|
# For example:
|
|
158
319
|
# b.remove_category "Spam"
|
|
@@ -164,13 +325,72 @@ module Classifier
|
|
|
164
325
|
# @rbs (String | Symbol) -> void
def remove_category(category)
  name = category.prepare_category_name
  synchronize do
    raise StandardError, "No such category: #{name}" unless @categories.key?(name)

    invalidate_caches
    @dirty = true
    @categories.delete(name)
    @category_counts.delete(name)
    # Subtract the category's word total (0 if it was never trained).
    removed_words = @category_word_count.delete(name)
    @total_words -= removed_words.to_i
    removed_words
  end
end
|
|
340
|
+
|
|
341
|
+
private
|
|
342
|
+
|
|
343
|
+
# Restores classifier state from a JSON string (used by reload).
# Raises ArgumentError unless the parsed 'type' is 'bayes'. The state swap
# itself happens under the lock.
# @rbs (String) -> void
def restore_from_json(json)
  parsed = JSON.parse(json)
  raise ArgumentError, "Invalid classifier type: #{parsed['type']}" if parsed['type'] != 'bayes'

  synchronize { restore_state(parsed) }
end
|
|
353
|
+
|
|
354
|
+
# Restores classifier state from a hash (used by from_json, and by
# reload/reload! via restore_from_json). Category names and words are
# converted back to Symbols, the caches are cleared, and the instance is
# marked clean.
# @rbs (Hash[String, untyped]) -> void
def restore_state(data)
  # from_json calls this on an #allocate'd instance that has no lock yet, so
  # one must be created. On the reload path this runs INSIDE #synchronize.
  # Unconditionally re-initialising there swapped in a fresh, unlocked mutex
  # while the old one was still held, letting other threads enter mid-restore.
  # NOTE(review): relies on mutex_m keeping its lock in @_mutex.
  mu_initialize unless instance_variable_defined?(:@_mutex)
  @categories = {} #: Hash[Symbol, Hash[Symbol, Integer]]
  @total_words = data['total_words']
  @category_counts = Hash.new(0) #: Hash[Symbol, Integer]
  @category_word_count = Hash.new(0) #: Hash[Symbol, Integer]
  @cached_training_count = nil
  @cached_vocab_size = nil
  @dirty = false
  # Keep an already-configured storage across reload/reload!. It used to be
  # reset to nil here, so after one reload the next save/reload raised
  # "No storage configured". A fresh #allocate'd instance still starts at nil.
  @storage = nil unless instance_variable_defined?(:@storage)

  data['categories'].each do |cat_name, words|
    @categories[cat_name.to_sym] = words.transform_keys(&:to_sym)
  end

  data['category_counts'].each do |cat_name, count|
    @category_counts[cat_name.to_sym] = count
  end

  data['category_word_count'].each do |cat_name, count|
    @category_word_count[cat_name.to_sym] = count
  end
end
|
|
379
|
+
|
|
380
|
+
# @rbs () -> void
def invalidate_caches
  # Forget memoized totals so the next reader recomputes them.
  @cached_training_count = @cached_vocab_size = nil
end
|
|
385
|
+
|
|
386
|
+
# @rbs () -> Float
def cached_training_count
  # Total number of trained documents across all categories, as a Float;
  # memoized until invalidate_caches runs.
  return @cached_training_count if @cached_training_count

  @cached_training_count = @category_counts.each_value.sum.to_f
end
|
|
170
390
|
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
@
|
|
391
|
+
# @rbs () -> Integer
def cached_vocab_size
  # Number of distinct words across all categories, never less than 1
  # (used as a smoothing denominator); memoized until invalidate_caches.
  @cached_vocab_size ||= begin
    vocabulary = {}
    @categories.each_value { |words| words.each_key { |word| vocabulary[word] = true } }
    vocabulary.empty? ? 1 : vocabulary.size
  end
end
|
|
175
395
|
end
|
|
176
396
|
end
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# rbs_inline: enabled
|
|
2
|
+
|
|
3
|
+
# Author:: Lucas Carlson (mailto:lucas@rufy.com)
|
|
4
|
+
# Copyright:: Copyright (c) 2005 Lucas Carlson
|
|
5
|
+
# License:: LGPL
|
|
6
|
+
|
|
7
|
+
module Classifier
  # Root of the library's exception hierarchy; rescue this to catch any
  # Classifier-specific failure.
  class Error < StandardError
  end

  # Raised by reload when it would discard unsaved (dirty) changes.
  class UnsavedChangesError < Error
  end

  # Raised when a storage operation fails (e.g. no saved state to read).
  class StorageError < Error
  end
end
|