classifier 2.2.0 → 2.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +19 -15
- data/exe/classifier +9 -0
- data/lib/classifier/cli.rb +880 -0
- data/lib/classifier/logistic_regression.rb +41 -19
- data/lib/classifier/version.rb +3 -0
- data/lib/classifier.rb +1 -0
- data/sig/classifier.rbs +3 -0
- data/sig/vendor/json.rbs +1 -0
- data/sig/vendor/optparse.rbs +19 -0
- metadata +23 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: c13b0ca80981d0186f038581e896182ca112928dcada4ac23700f2a9642ca785
+  data.tar.gz: 1949dea18b6d7e06eb931d411e821757f75084c6776ea589927b1f1b89d82280
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 562474db108e959f218407bd66a58bffe1b609ea38b8db16d15f579196d41d4976d74a66b6928e195ff0f6cc5dc4b081ea75d634cefb928e25932f8867358384
+  data.tar.gz: 39f72068f85e64a6397672376f6fcc4821e1eccd9556813e4894ba62800975330051c87e88a8747f03278e6ad0e83a665ed48fcc6cc1f623cd91b657089f8dd6
data/README.md
CHANGED
@@ -6,7 +6,7 @@
 
 Text classification in Ruby. Five algorithms, native performance, streaming support.
 
-**[Documentation](https://rubyclassifier.com/docs)** · **[Tutorials](https://rubyclassifier.com/docs/tutorials)** · **[API Reference](https://
+**[Documentation](https://rubyclassifier.com/docs)** · **[Tutorials](https://rubyclassifier.com/docs/tutorials)** · **[API Reference](https://rubydoc.info/gems/classifier)**
 
 ## Why This Library?
 
@@ -16,7 +16,7 @@ Text classification in Ruby. Five algorithms, native performance, streaming supp
 | **Incremental LSI** | ✅ Brand's algorithm (no rebuild) | ❌ Full SVD rebuild on every add |
 | **LSI Performance** | ✅ Native C extension (5-50x faster) | ❌ Pure Ruby or requires GSL |
 | **Streaming** | ✅ Train on multi-GB datasets | ❌ Must load all data in memory |
-| **Persistence** | ✅ Pluggable (file, Redis, S3) | ❌ Marshal only |
+| **Persistence** | ✅ Pluggable (file, Redis, S3, SQL, Custom) | ❌ Marshal only |
 
 ## Installation
 
@@ -30,8 +30,10 @@ gem 'classifier'
 
 ```ruby
 classifier = Classifier::Bayes.new(:spam, :ham)
-classifier.train(spam: "Buy cheap
-classifier.
+classifier.train(spam: "Buy viagra cheap pills now")
+classifier.train(spam: "You won million dollars prize")
+classifier.train(ham: ["Meeting tomorrow at 3pm", "Quarterly report attached"])
+classifier.classify("Cheap pills!") # => "Spam"
 ```
 [Bayesian Guide →](https://rubyclassifier.com/docs/guides/bayes/basics)
 
@@ -39,8 +41,9 @@ classifier.classify "You've won a prize!" # => "Spam"
 
 ```ruby
 classifier = Classifier::LogisticRegression.new(:positive, :negative)
-classifier.train(positive: "
-classifier.
+classifier.train(positive: "love amazing great wonderful")
+classifier.train(negative: "hate terrible awful bad")
+classifier.classify("I love it!") # => "Positive"
 ```
 [Logistic Regression Guide →](https://rubyclassifier.com/docs/guides/logisticregression/basics)
 
@@ -48,8 +51,8 @@ classifier.classify "Loved it!" # => "Positive"
 
 ```ruby
 lsi = Classifier::LSI.new
-lsi.add(
-lsi.classify
+lsi.add(dog: "dog puppy canine bark fetch", cat: "cat kitten feline meow purr")
+lsi.classify("My puppy barks") # => "dog"
 ```
 [LSI Guide →](https://rubyclassifier.com/docs/guides/lsi/basics)
 
@@ -57,8 +60,9 @@ lsi.classify "My puppy is playful" # => "pets"
 
 ```ruby
 knn = Classifier::KNN.new(k: 3)
-
-
+%w[laptop coding software developer programming].each { |w| knn.add(tech: w) }
+%w[football basketball soccer goal team].each { |w| knn.add(sports: w) }
+knn.classify("programming code") # => "tech"
 ```
 [k-Nearest Neighbors Guide →](https://rubyclassifier.com/docs/guides/knn/basics)
 
@@ -66,8 +70,8 @@ knn.classify "Claim your prize" # => "spam"
 
 ```ruby
 tfidf = Classifier::TFIDF.new
-tfidf.fit(["
-tfidf.transform("
+tfidf.fit(["Ruby is great", "Python is great", "Ruby on Rails"])
+tfidf.transform("Ruby programming") # => {:rubi => 1.0}
 ```
 [TF-IDF Guide →](https://rubyclassifier.com/docs/guides/tfidf/basics)
 
@@ -87,7 +91,7 @@ lsi.add(tech: "Go is fast")
 lsi.add(tech: "Rust is safe")
 ```
 
-[Learn more →](https://rubyclassifier.com/docs/guides/lsi/
+[Learn more →](https://rubyclassifier.com/docs/guides/lsi/basics)
 
 ### Persistence
 
@@ -98,7 +102,7 @@ classifier.save
 loaded = Classifier::Bayes.load(storage: classifier.storage)
 ```
 
-[Learn more →](https://rubyclassifier.com/docs/guides/persistence)
+[Learn more →](https://rubyclassifier.com/docs/guides/persistence/basics)
 
 ### Streaming Training
 
@@ -106,7 +110,7 @@ loaded = Classifier::Bayes.load(storage: classifier.storage)
 classifier.train_from_stream(:spam, File.open("spam_corpus.txt"))
 ```
 
-[Learn more →](https://rubyclassifier.com/docs/
+[Learn more →](https://rubyclassifier.com/docs/tutorials/streaming-training)
 
 ## Performance
 
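The persistence entry in this README diff is only a pointer to the guide. As a rough sketch of the pluggable-persistence flow (assuming the file-backed `Classifier::Storage::File` backend that the new CLI code in this release uses; other backends such as Redis or S3 would plug in the same way), saving and reloading a trained Bayes model looks roughly like this:

```ruby
require 'classifier'

classifier = Classifier::Bayes.new(:spam, :ham)
classifier.train(spam: "Buy cheap pills now")
classifier.train(ham: "Meeting tomorrow at 3pm")

# Assumed backend name; taken from the CLI's save_classifier helper below.
classifier.storage = Classifier::Storage::File.new(path: './classifier.json')
classifier.save

loaded = Classifier::Bayes.load(storage: classifier.storage)
loaded.classify("Cheap pills!") # => "Spam"
```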
data/exe/classifier
ADDED
data/lib/classifier/cli.rb
ADDED
@@ -0,0 +1,880 @@
# rbs_inline: enabled

require 'json'
require 'optparse'
require 'net/http'
require 'uri'
require 'fileutils'
require 'classifier'

module Classifier
  class CLI
    # @rbs @args: Array[String]
    # @rbs @stdin: String?
    # @rbs @options: Hash[Symbol, untyped]
    # @rbs @output: Array[String]
    # @rbs @error: Array[String]
    # @rbs @exit_code: Integer
    # @rbs @parser: OptionParser

    CLASSIFIER_TYPES = {
      'bayes' => :bayes,
      'lsi' => :lsi,
      'knn' => :knn,
      'lr' => :logistic_regression,
      'logistic_regression' => :logistic_regression
    }.freeze

    DEFAULT_REGISTRY = ENV.fetch('CLASSIFIER_REGISTRY', 'cardmagic/classifier-models') #: String
    CACHE_DIR = ENV.fetch('CLASSIFIER_CACHE', File.expand_path('~/.classifier')) #: String

    def initialize(args, stdin: nil)
      @args = args.dup
      @stdin = stdin
      @options = {
        model: ENV.fetch('CLASSIFIER_MODEL', './classifier.json'),
        type: ENV.fetch('CLASSIFIER_TYPE', 'bayes'),
        probabilities: false,
        quiet: false,
        count: 10,
        k: 5,
        weighted: false,
        learning_rate: nil,
        regularization: nil,
        max_iterations: nil,
        remote: nil,
        output_path: nil
      }
      @output = [] #: Array[String]
      @error = [] #: Array[String]
      @exit_code = 0
    end

    def run
      parse_options
      execute_command
      { output: @output.join("\n"), error: @error.join("\n"), exit_code: @exit_code }
    rescue OptionParser::InvalidOption, OptionParser::MissingArgument, OptionParser::InvalidArgument => e
      @error << "Error: #{e.message}"
      @exit_code = 2
      { output: @output.join("\n"), error: @error.join("\n"), exit_code: @exit_code }
    rescue StandardError => e
      @error << "Error: #{e.message}"
      @exit_code = 1
      { output: @output.join("\n"), error: @error.join("\n"), exit_code: @exit_code }
    end

    private

    def parse_options
      @parser = OptionParser.new do |opts|
        opts.banner = 'Usage: classifier [options] [command] [arguments]'
        opts.separator ''
        opts.separator 'Commands:'
        opts.separator '  train <category> [files...]  Train a category from files or stdin'
        opts.separator '  info                         Show model information'
        opts.separator '  fit                          Fit the model (logistic regression)'
        opts.separator '  search <query>               Semantic search (LSI only)'
        opts.separator '  related <item>               Find related documents (LSI only)'
        opts.separator '  models [registry]            List models in registry'
        opts.separator '  pull <model>                 Download model from registry'
        opts.separator '  push <file>                  Contribute model to registry'
        opts.separator '  <text>                       Classify text (default action)'
        opts.separator ''
        opts.separator 'Options:'

        opts.on('-f', '--file FILE', 'Model file (default: ./classifier.json)') do |file|
          @options[:model] = file
        end

        opts.on('-m', '--model TYPE', 'Classifier model: bayes, lsi, knn, lr (default: bayes)') do |type|
          unless CLASSIFIER_TYPES.key?(type)
            raise OptionParser::InvalidArgument, "Unknown classifier model: #{type}. Valid models: #{CLASSIFIER_TYPES.keys.join(', ')}"
          end

          @options[:type] = type
        end

        opts.on('-r', '--remote MODEL', 'Use remote model: name or @user/repo:name') do |model|
          @options[:remote] = model
        end

        opts.on('-o', '--output FILE', 'Output path for pull command') do |file|
          @options[:output_path] = file
        end

        opts.on('-p', 'Show probabilities') do
          @options[:probabilities] = true
        end

        opts.on('-n', '--count N', Integer, 'Number of results for search/related (default: 10)') do |n|
          @options[:count] = n
        end

        opts.on('-k', '--neighbors N', Integer, 'Number of neighbors for KNN (default: 5)') do |n|
          @options[:k] = n
        end

        opts.on('--weighted', 'Use distance-weighted voting for KNN') do
          @options[:weighted] = true
        end

        opts.on('--learning-rate N', Float, 'Learning rate for logistic regression (default: 0.1)') do |n|
          @options[:learning_rate] = n
        end

        opts.on('--regularization N', Float, 'L2 regularization for logistic regression (default: 0.01)') do |n|
          @options[:regularization] = n
        end

        opts.on('--max-iterations N', Integer, 'Max iterations for logistic regression (default: 100)') do |n|
          @options[:max_iterations] = n
        end

        opts.on('-q', 'Quiet mode') do
          @options[:quiet] = true
        end

        opts.on('--local', 'List locally cached models (for models command)') do
          @options[:local] = true
        end

        opts.on('-v', '--version', 'Show version') do
          @output << Classifier::VERSION
          @exit_code = 0
          throw :done
        end

        opts.on('-h', '--help', 'Show help') do
          @output << opts.to_s
          @exit_code = 0
          throw :done
        end
      end

      catch(:done) do
        @parser.parse!(@args)
      end
    end

    def execute_command
      return if @exit_code != 0 || @output.any?

      command = @args.first

      case command
      when 'train'
        command_train
      when 'info'
        command_info
      when 'fit'
        command_fit
      when 'search'
        command_search
      when 'related'
        command_related
      when 'models'
        command_models
      when 'pull'
        command_pull
      when 'push'
        command_push
      else
        command_classify
      end
    end

    def command_train
      @args.shift # remove 'train'
      category = @args.shift

      unless category
        @error << 'Error: category required for train command'
        @exit_code = 2
        return
      end

      classifier = load_or_create_classifier

      if classifier.is_a?(LSI) && @args.any?
        train_lsi_from_files(classifier, category, @args)
        save_classifier(classifier)
        return
      end

      text = read_training_input
      if text.empty?
        @error << 'Error: no training data provided'
        @exit_code = 2
        return
      end

      train_classifier(classifier, category, text)
      save_classifier(classifier)
    end

    def command_info
      unless File.exist?(@options[:model])
        @error << "Error: model not found at #{@options[:model]}"
        @exit_code = 1
        return
      end

      classifier = load_classifier
      info = build_model_info(classifier)
      @output << JSON.pretty_generate(info)
    end

    def build_model_info(classifier)
      info = { file: @options[:model], type: classifier_type_name(classifier) }
      add_common_info(info, classifier)
      add_classifier_specific_info(info, classifier)
      info
    end

    def add_common_info(info, classifier)
      info[:categories] = classifier.categories.map(&:to_s) if classifier.respond_to?(:categories)
      info[:training_count] = classifier.training_count if classifier.respond_to?(:training_count)
      info[:vocab_size] = classifier.vocab_size if classifier.respond_to?(:vocab_size)
      info[:fitted] = classifier.fitted? if classifier.respond_to?(:fitted?)
    end

    def add_classifier_specific_info(info, classifier)
      case classifier
      when Bayes then add_bayes_info(info, classifier)
      when LSI then add_lsi_info(info, classifier)
      when KNN then add_knn_info(info, classifier)
      end
    end

    def add_bayes_info(info, classifier)
      categories_data = classifier.instance_variable_get(:@categories)
      info[:category_stats] = classifier.categories.to_h do |cat|
        cat_data = categories_data[cat.to_sym] || {}
        [cat.to_s, { unique_words: cat_data.size, total_words: cat_data.values.sum }]
      end
    end

    def add_lsi_info(info, classifier)
      info[:documents] = classifier.items.size
      info[:items] = classifier.items
      categories = classifier.items.map { |item| classifier.categories_for(item) }.flatten.uniq
      info[:categories] = categories.map(&:to_s) unless categories.empty?
    end

    def add_knn_info(info, classifier)
      data = classifier.instance_variable_get(:@data) || []
      info[:documents] = data.size
      categories = data.map { |d| d[:category] }.uniq
      info[:categories] = categories.map(&:to_s) unless categories.empty?
    end

    def command_fit
      unless File.exist?(@options[:model])
        @error << "Error: model not found at #{@options[:model]}"
        @exit_code = 1
        return
      end

      classifier = load_classifier

      unless classifier.respond_to?(:fit)
        @output << 'Model does not require fitting' unless @options[:quiet]
        return
      end

      classifier.fit
      save_classifier(classifier)
      @output << 'Model fitted successfully' unless @options[:quiet]
    end

    def command_search
      @args.shift # remove 'search'

      unless File.exist?(@options[:model])
        @error << "Error: model not found at #{@options[:model]}"
        @exit_code = 1
        return
      end

      classifier = load_classifier

      unless classifier.is_a?(LSI)
        @error << 'Error: search requires LSI model (use -t lsi)'
        @exit_code = 1
        return
      end

      query = @args.join(' ')
      query = read_stdin_line if query.empty?

      if query.empty?
        @error << 'Error: search query required'
        @exit_code = 2
        return
      end

      results = classifier.search(query, @options[:count])
      results.each do |item|
        score = classifier.proximity_norms_for_content(query).find { |i, _| i == item }&.last || 0
        @output << "#{item}:#{format('%.2f', score)}"
      end
    end

    def command_related
      @args.shift # remove 'related'
      item = @args.shift

      unless item
        @error << 'Error: item required for related command'
        @exit_code = 2
        return
      end

      unless File.exist?(@options[:model])
        @error << "Error: model not found at #{@options[:model]}"
        @exit_code = 1
        return
      end

      classifier = load_classifier

      unless classifier.is_a?(LSI)
        @error << 'Error: related requires LSI model (use -t lsi)'
        @exit_code = 1
        return
      end

      unless classifier.items.include?(item)
        @error << "Error: item not found in model: #{item}"
        @exit_code = 1
        return
      end

      results = classifier.find_related(item, @options[:count])
      results.each do |related_item|
        scores = classifier.proximity_array_for_content(item)
        score = scores.find { |i, _| i == related_item }&.last || 0
        @output << "#{related_item}:#{format('%.2f', score)}"
      end
    end

    def command_models
      @args.shift # remove 'models'

      if @options[:local]
        list_local_models
      else
        list_remote_models
      end
    end

    def list_remote_models
      registry_arg = @args.shift
      registry = parse_registry(registry_arg) || DEFAULT_REGISTRY
      index = fetch_registry_index(registry)

      return if @exit_code != 0

      if index['models'].empty?
        @output << 'No models found in registry'
        return
      end

      index['models'].each do |name, info|
        type = info['type'] || 'unknown'
        size = info['size'] || 'unknown'
        desc = info['description'] || ''
        @output << format('%-20<name>s %<desc>s (%<type>s, %<size>s)', name: name, desc: desc.slice(0, 40), type: type, size: size)
      end
    end

    def list_local_models
      models_dir = File.join(CACHE_DIR, 'models')

      unless Dir.exist?(models_dir)
        @output << 'No local models found'
        return
      end

      # Find models from default registry
      default_models = Dir.glob(File.join(models_dir, '*.json')).map do |path|
        { name: File.basename(path, '.json'), registry: nil, path: path }
      end

      # Find models from custom registries (@user/repo structure)
      custom_models = Dir.glob(File.join(models_dir, '@*', '*', '*.json')).map do |path|
        # Extract registry from path: .../models/@user/repo/model.json
        repo_dir = File.dirname(path)
        user_dir = File.dirname(repo_dir)
        registry = "#{File.basename(user_dir).delete_prefix('@')}/#{File.basename(repo_dir)}"
        { name: File.basename(path, '.json'), registry: registry, path: path }
      end

      models = default_models + custom_models #: Array[{name: String, registry: String?, path: String}]

      if models.empty?
        @output << 'No local models found'
        return
      end

      models.each do |model|
        info = load_model_info(model[:path])
        type = info['type'] || 'unknown'
        display_name = model[:registry] ? "@#{model[:registry]}:#{model[:name]}" : model[:name]
        size = File.size(model[:path])
        @output << format('%-30<name>s (%<type>s, %<size>s)', name: display_name, type: type, size: human_size(size))
      end
    end

    def load_model_info(path)
      JSON.parse(File.read(path))
    rescue JSON::ParserError
      {}
    end

    def human_size(bytes)
      units = %w[B KB MB GB]
      unit_index = 0
      size = bytes.to_f

      while size >= 1024 && unit_index < units.length - 1
        size /= 1024
        unit_index += 1
      end

      format('%<size>.1f %<unit>s', size: size, unit: units[unit_index])
    end

    def command_pull
      @args.shift # remove 'pull'
      model_spec = @args.shift

      unless model_spec
        @error << 'Error: model name required for pull command'
        @exit_code = 2
        return
      end

      registry, model_name = parse_model_spec(model_spec)
      registry ||= DEFAULT_REGISTRY

      if model_name.nil?
        pull_all_models(registry)
      else
        pull_single_model(registry, model_name)
      end
    end

    def pull_single_model(registry, model_name)
      index = fetch_registry_index(registry)
      return if @exit_code != 0

      model_info = index['models'][model_name]
      unless model_info
        @error << "Error: model '#{model_name}' not found in registry #{registry}"
        @exit_code = 1
        return
      end

      file_path = model_info['file'] || "models/#{model_name}.json"
      output_path = @options[:output_path] || cache_path_for(registry, model_name)

      @output << "Downloading #{model_name} from #{registry}..." unless @options[:quiet]

      content = fetch_github_file(registry, file_path)
      return if @exit_code != 0

      FileUtils.mkdir_p(File.dirname(output_path))
      File.write(output_path, content)

      @output << "Saved to #{output_path}" unless @options[:quiet]
    end

    def pull_all_models(registry)
      index = fetch_registry_index(registry)
      return if @exit_code != 0

      if index['models'].empty?
        @output << 'No models found in registry'
        return
      end

      @output << "Downloading #{index['models'].size} models from #{registry}..." unless @options[:quiet]

      index['models'].each_key do |model_name|
        pull_single_model(registry, model_name)
        break if @exit_code != 0
      end
    end

    def command_push
      @args.shift # remove 'push'

      @output << 'To contribute a model to the registry:'
      @output << ''
      @output << '1. Fork https://github.com/cardmagic/classifier-models'
      @output << '2. Add your model to the models/ directory'
      @output << '3. Update models.json with your model metadata'
      @output << '4. Create a pull request'
      @output << ''
      @output << 'Or use the GitHub CLI:'
      @output << ''
      @output << '  gh repo fork cardmagic/classifier-models --clone'
      @output << '  cp ./classifier.json classifier-models/models/my-model.json'
      @output << '  # Edit classifier-models/models.json to add your model'
      @output << '  cd classifier-models && gh pr create'
    end

    def command_classify
      text = @args.join(' ')

      if @options[:remote]
        classify_with_remote(text)
        return
      end

      if text.empty? && ($stdin.tty? || @stdin.nil?) && !File.exist?(@options[:model])
        show_getting_started
        return
      end

      unless File.exist?(@options[:model])
        @error << "Error: model not found at #{@options[:model]}"
        @exit_code = 1
        return
      end

      classifier = load_classifier

      if text.empty?
        lines = read_stdin_lines
        return show_model_usage(classifier) if lines.empty?

        lines.each { |line| classify_and_output(classifier, line) }
      else
        classify_and_output(classifier, text)
      end
    end

    def classify_with_remote(text)
      registry, model_name = parse_model_spec(@options[:remote])
      registry ||= DEFAULT_REGISTRY

      unless model_name
        @error << 'Error: model name required for -r option'
        @exit_code = 2
        return
      end

      cached_path = cache_path_for(registry, model_name)

      unless File.exist?(cached_path)
        pull_single_model(registry, model_name)
        return if @exit_code != 0
      end

      original_model = @options[:model]
      @options[:model] = cached_path

      begin
        classifier = load_classifier

        if text.empty?
          lines = read_stdin_lines
          return show_model_usage(classifier) if lines.empty?

          lines.each { |line| classify_and_output(classifier, line) }
        else
          classify_and_output(classifier, text)
        end
      ensure
        @options[:model] = original_model
      end
    end

    # @rbs (untyped) -> void
    def show_model_usage(classifier)
      type = classifier_type_name(classifier)
      cats = classifier.categories.map(&:to_s).map(&:downcase)
      first_cat = cats.first || 'category'

      @output << "Model: #{@options[:model]} (#{type})"
      @output << "Categories: #{cats.join(', ')}"
      @output << ''
      @output << 'Classify text:'
      @output << ''
      @output << "  classifier 'text to classify'"
      @output << "  echo 'text to classify' | classifier"
      @output << ''
      @output << 'Train more data:'
      @output << ''
      @output << "  echo 'new example text' | classifier train #{first_cat}"
      @output << "  classifier train #{first_cat} file1.txt file2.txt"
      @output << ''
      @output << 'Other commands:'
      @output << ''
      @output << '  classifier info    Show model details (JSON)'
    end

    def classify_and_output(classifier, text)
      return if text.strip.empty?

      if classifier.is_a?(LogisticRegression) && !classifier.fitted?
        raise StandardError, "Model not fitted. Run 'classifier fit' after training."
      end

      if @options[:probabilities]
        probs = get_probabilities(classifier, text)
        formatted = probs.map { |cat, prob| "#{cat.downcase}:#{format('%.2f', prob)}" }.join(' ')
        @output << formatted
      else
        result = classifier.classify(text)
        @output << result.downcase
      end
    end

    def get_probabilities(classifier, text)
      if classifier.respond_to?(:probabilities)
        classifier.probabilities(text)
      elsif classifier.respond_to?(:classifications)
        scores = classifier.classifications(text)
        normalize_scores(scores)
      else
        { classifier.classify(text) => 1.0 }
      end
    end

    def normalize_scores(scores)
      max_score = scores.values.max
      exp_scores = scores.transform_values { |s| Math.exp(s - max_score) }
      total = exp_scores.values.sum.to_f
      exp_scores.transform_values { |s| (s / total).to_f }
    end

    def load_or_create_classifier
      if File.exist?(@options[:model])
        load_classifier
      else
        create_classifier
      end
    end

    def load_classifier
      json = File.read(@options[:model])
      data = JSON.parse(json)
      type = data['type']

      case type
      when 'bayes'
        Bayes.from_json(data)
      when 'lsi'
        LSI.from_json(data)
      when 'knn'
        KNN.from_json(data)
      when 'logistic_regression'
        LogisticRegression.from_json(data)
      else
        raise "Unknown classifier type in model: #{type}"
      end
    end

    def create_classifier
      type = CLASSIFIER_TYPES[@options[:type]] || :bayes

      case type
      when :lsi
        LSI.new(auto_rebuild: true)
      when :knn
        KNN.new(k: @options[:k], weighted: @options[:weighted])
      when :logistic_regression
        lr_opts = {} #: Hash[Symbol, untyped]
        lr_opts[:learning_rate] = @options[:learning_rate] if @options[:learning_rate]
        lr_opts[:regularization] = @options[:regularization] if @options[:regularization]
        lr_opts[:max_iterations] = @options[:max_iterations] if @options[:max_iterations]
        LogisticRegression.new(**lr_opts)
      else # :bayes or unknown defaults to Bayes
        Bayes.new
      end
    end

    def train_classifier(classifier, category, text)
      case classifier
      when Bayes, LogisticRegression
        classifier.add_category(category) unless classifier.categories.include?(category)
        text.each_line { |line| classifier.train(category, line.strip) unless line.strip.empty? }
      when LSI
        text.each_line do |line|
          next if line.strip.empty?

          classifier.add_item(line.strip, category.to_sym)
        end
      when KNN
        text.each_line do |line|
          next if line.strip.empty?

          classifier.add(category.to_sym => line.strip)
        end
      end
    end

    def train_lsi_from_files(classifier, category, files)
      files.each do |file|
        content = File.read(file)
        classifier.add_item(file, category.to_sym) { content }
      end
    end

    def save_classifier(classifier)
      classifier.storage = Storage::File.new(path: @options[:model])
      classifier.save
    end

    def classifier_type_name(classifier)
      case classifier
      when Bayes then 'bayes'
      when LSI then 'lsi'
      when KNN then 'knn'
      when LogisticRegression then 'logistic_regression'
      else 'unknown'
      end
    end

    def read_training_input
      if @args.any?
        @args.map { |file| File.read(file) }.join("\n")
      else
        read_stdin
      end
    end

    def read_stdin
      @stdin || ($stdin.tty? ? '' : $stdin.read)
    end

    def read_stdin_line
      (@stdin || ($stdin.tty? ? '' : $stdin.read)).to_s.strip
    end

    def read_stdin_lines
      read_stdin.to_s.split("\n").map(&:strip).reject(&:empty?)
    end

    # @rbs () -> void
    def show_getting_started
      @output << 'Classifier - Text classification from the command line'
      @output << ''
      @output << 'Get started by training some categories:'
      @output << ''
      @output << '  # Train from files'
      @output << '  classifier train spam spam_emails/*.txt'
      @output << '  classifier train ham good_emails/*.txt'
      @output << ''
      @output << '  # Train from stdin'
      @output << "  echo 'buy viagra now free pills cheap meds' | classifier train spam"
      @output << "  echo 'meeting scheduled for tomorrow to discuss project' | classifier train ham"
      @output << ''
      @output << 'Then classify text:'
      @output << ''
      @output << "  classifier 'free money buy now'"
      @output << "  classifier 'meeting postponed to friday'"
      @output << ''
      @output << 'Use LSI for semantic search:'
      @output << ''
      @output << "  echo 'ruby is a dynamic programming language' | classifier train docs -m lsi"
      @output << "  echo 'python is great for data science' | classifier train docs -m lsi"
      @output << "  classifier search 'programming'"
      @output << ''
      @output << 'Options:'
      @output << '  -f FILE   Model file (default: ./classifier.json)'
      @output << '  -m TYPE   Model type: bayes, lsi, knn, lr (default: bayes)'
      @output << '  -r MODEL  Use remote model from registry'
      @output << '  -p        Show probabilities'
      @output << ''
      @output << 'Use pre-trained models:'
      @output << ''
      @output << '  classifier models              List available models'
      @output << '  classifier pull sentiment      Download a model'
      @output << "  classifier -r sentiment 'I love this!'   Classify with remote model"
      @output << ''
      @output << 'Run "classifier --help" for full usage.'
    end

    # Parse @user/repo format to extract registry
    # @rbs (String?) -> String?
    def parse_registry(arg)
      return nil unless arg
      return nil unless arg.start_with?('@')

      # @user/repo format
      arg[1..] # Remove @ prefix
    end

    # Parse model spec: name, @user/repo:name, or @user/repo (for all models)
    # Returns [registry, model_name] where model_name is nil if pulling all
    # @rbs (String) -> [String?, String?]
    def parse_model_spec(spec)
      if spec.start_with?('@')
        # @user/repo:model or @user/repo
        rest = spec[1..] || ''
        if spec.include?(':')
          parts = rest.split(':', 2)
          [parts[0], parts[1]]
        else
          # @user/repo - pull all models from registry
          [rest, nil]
        end
      else
        # Just a model name from default registry
        [nil, spec]
      end
    end

    # Get cache path for a model
    # @rbs (String, String) -> String
    def cache_path_for(registry, model_name)
      if registry == DEFAULT_REGISTRY
        File.join(CACHE_DIR, 'models', "#{model_name}.json")
      else
        File.join(CACHE_DIR, 'models', "@#{registry}", "#{model_name}.json")
      end
    end

    # Fetch models.json index from a registry
    # @rbs (String) -> Hash[String, untyped]
    def fetch_registry_index(registry)
      content = fetch_github_file(registry, 'models.json')
      return { 'models' => {} } if @exit_code != 0

      JSON.parse(content)
    rescue JSON::ParserError => e
      @error << "Error: invalid models.json in registry: #{e.message}"
      @exit_code = 1
      { 'models' => {} }
    end

    # Fetch a file from GitHub raw content
    # @rbs (String, String) -> String
    def fetch_github_file(registry, file_path)
      url = "https://raw.githubusercontent.com/#{registry}/main/#{file_path}"
      uri = URI.parse(url)

      response = Net::HTTP.get_response(uri)

      unless response.is_a?(Net::HTTPSuccess)
        # Try master branch if main fails
        url = "https://raw.githubusercontent.com/#{registry}/master/#{file_path}"
        uri = URI.parse(url)
        response = Net::HTTP.get_response(uri)
      end

      unless response.is_a?(Net::HTTPSuccess)
        @error << "Error: failed to fetch #{file_path} from #{registry} (#{response.code})"
        @exit_code = 1
        return ''
      end

      response.body
    end
  end
end
data/lib/classifier/logistic_regression.rb
CHANGED
@@ -62,8 +62,6 @@ module Classifier
                    tolerance: DEFAULT_TOLERANCE)
       super()
       categories = categories.flatten
-      raise ArgumentError, 'At least two categories required' if categories.size < 2
-
       @categories = categories.map { |c| c.to_s.prepare_category_name }
       @weights = @categories.to_h { |c| [c, {}] }
       @bias = @categories.to_h { |c| [c, 0.0] }
@@ -99,6 +97,7 @@ module Classifier
     def fit
       synchronize do
         return self if @training_data.empty?
+        raise ArgumentError, 'At least two categories required for fitting' if @categories.size < 2
 
         optimize_weights
         @fitted = true
@@ -122,13 +121,14 @@ module Classifier
 
     # Returns probability distribution across all categories.
     # Probabilities are well-calibrated (unlike Naive Bayes).
+    # Raises NotFittedError if model has not been fitted.
     #
     #   classifier.probabilities("Buy now!")
     #   # => {"Spam" => 0.92, "Ham" => 0.08}
     #
     # @rbs (String) -> Hash[String, Float]
     def probabilities(text)
-      fit unless @fitted
+      raise NotFittedError, 'Model not fitted. Call fit() after training.' unless @fitted
 
       features = text.word_hash
       synchronize do
@@ -137,10 +137,11 @@ module Classifier
     end
 
     # Returns log-odds scores for each category (before softmax).
+    # Raises NotFittedError if model has not been fitted.
     #
     # @rbs (String) -> Hash[String, Float]
     def classifications(text)
-      fit unless @fitted
+      raise NotFittedError, 'Model not fitted. Call fit() after training.' unless @fitted
 
       features = text.word_hash
       synchronize do
@@ -173,6 +174,23 @@ module Classifier
       synchronize { @categories.map(&:to_s) }
     end
 
+    # Adds a new category to the classifier.
+    # Allows dynamic category creation for CLI and incremental training.
+    #
+    # @rbs (String | Symbol) -> void
+    def add_category(category)
+      cat = category.to_s.prepare_category_name
+      synchronize do
+        return if @categories.include?(cat)
+
+        @categories << cat
+        @weights[cat] = {}
+        @bias[cat] = 0.0
+        @fitted = false
+        @dirty = true
+      end
+    end
+
     # Returns true if the model has been fitted.
     #
     # @rbs () -> bool
@@ -205,11 +223,10 @@ module Classifier
     end
 
     # Returns a hash representation of the classifier state.
+    # Does NOT auto-fit; saves current state including unfitted models.
     #
     # @rbs (?untyped) -> Hash[Symbol, untyped]
     def as_json(_options = nil)
-      fit unless @fitted
-
       {
         version: 1,
         type: 'logistic_regression',
@@ -217,10 +234,12 @@ module Classifier
         weights: @weights.transform_keys(&:to_s).transform_values { |v| v.transform_keys(&:to_s) },
         bias: @bias.transform_keys(&:to_s),
         vocabulary: @vocabulary.keys.map(&:to_s),
+        training_data: @training_data.map { |d| { category: d[:category].to_s, features: d[:features].transform_keys(&:to_s) } },
         learning_rate: @learning_rate,
         regularization: @regularization,
         max_iterations: @max_iterations,
-        tolerance: @tolerance
+        tolerance: @tolerance,
+        fitted: @fitted
       }
     end
 
@@ -546,26 +565,29 @@ module Classifier
     def restore_state(data, categories)
       mu_initialize
       @categories = categories
+      restore_weights_and_bias(data)
+      restore_hyperparameters(data)
+      @fitted = data.fetch('fitted', true)
+      @dirty = false
+      @storage = nil
+    end
+
+    def restore_weights_and_bias(data)
       @weights = {}
       @bias = {}
-
-      data['
-
-
-
-      data['bias'].each do |cat, value|
-        @bias[cat.to_sym] = value.to_f
+      data['weights'].each { |cat, words| @weights[cat.to_sym] = words.transform_keys(&:to_sym).transform_values(&:to_f) }
+      data['bias'].each { |cat, value| @bias[cat.to_sym] = value.to_f }
+      @vocabulary = data['vocabulary'].to_h { |v| [v.to_sym, true] }
+      @training_data = (data['training_data'] || []).map do |d|
+        { category: d['category'].to_sym, features: d['features'].transform_keys(&:to_sym).transform_values(&:to_i) }
       end
+    end
 
-
+    def restore_hyperparameters(data)
       @learning_rate = data['learning_rate']
       @regularization = data['regularization']
       @max_iterations = data['max_iterations']
       @tolerance = data['tolerance']
-      @training_data = []
-      @fitted = true
-      @dirty = false
-      @storage = nil
     end
   end
 end
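A small sketch of the behavioral change above: `probabilities` and `classifications` no longer auto-fit, so callers must fit explicitly after training (class and method names as in the README examples; the probability values are illustrative):

```ruby
classifier = Classifier::LogisticRegression.new(:positive, :negative)
classifier.train(positive: "love amazing great wonderful")
classifier.train(negative: "hate terrible awful bad")

# classifier.probabilities("I love it!")  # would now raise NotFittedError
classifier.fit
classifier.probabilities("I love it!") # => e.g. {"Positive" => 0.9, "Negative" => 0.1}
```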
data/lib/classifier.rb
CHANGED
data/sig/classifier.rbs
ADDED
data/sig/vendor/json.rbs
CHANGED
data/sig/vendor/optparse.rbs
ADDED
@@ -0,0 +1,19 @@
# Minimal type definitions for optparse stdlib

class OptionParser
  class InvalidOption < StandardError
  end

  class MissingArgument < StandardError
  end

  class InvalidArgument < StandardError
  end

  def initialize: () { (OptionParser) -> void } -> void
  def banner=: (String) -> String
  def separator: (String) -> void
  def on: (*untyped) ?{ (*untyped) -> untyped } -> void
  def to_s: () -> String
  def parse!: (Array[String]) -> Array[String]
end
metadata
CHANGED
@@ -1,11 +1,11 @@
 --- !ruby/object:Gem::Specification
 name: classifier
 version: !ruby/object:Gem::Version
-  version: 2.
+  version: 2.3.0
 platform: ruby
 authors:
 - Lucas Carlson
-bindir:
+bindir: exe
 cert_chain: []
 date: 1980-01-02 00:00:00.000000000 Z
 dependencies:
@@ -121,12 +121,27 @@ dependencies:
   - - ">="
     - !ruby/object:Gem::Version
       version: '0'
+- !ruby/object:Gem::Dependency
+  name: webmock
+  requirement: !ruby/object:Gem::Requirement
+    requirements:
+    - - ">="
+      - !ruby/object:Gem::Version
+        version: '0'
+  type: :development
+  prerelease: false
+  version_requirements: !ruby/object:Gem::Requirement
+    requirements:
+    - - ">="
+      - !ruby/object:Gem::Version
+        version: '0'
 description: A Ruby library for text classification featuring Naive Bayes, LSI (Latent
   Semantic Indexing), Logistic Regression, and k-Nearest Neighbors classifiers. Includes
   TF-IDF vectorization, streaming/incremental training, pluggable persistence backends,
   thread safety, and a native C extension for fast LSI operations.
 email: lucas@rufy.com
-executables:
+executables:
+- classifier
 extensions:
 - ext/classifier/extconf.rb
 extra_rdoc_files: []
@@ -136,6 +151,7 @@ files:
 - README.md
 - bin/bayes.rb
 - bin/summarize.rb
+- exe/classifier
 - ext/classifier/classifier_ext.c
 - ext/classifier/extconf.rb
 - ext/classifier/incremental_svd.c
@@ -145,6 +161,7 @@ files:
 - ext/classifier/vector.c
 - lib/classifier.rb
 - lib/classifier/bayes.rb
+- lib/classifier/cli.rb
 - lib/classifier/errors.rb
 - lib/classifier/extensions/string.rb
 - lib/classifier/extensions/vector.rb
@@ -164,11 +181,14 @@ files:
 - lib/classifier/streaming/line_reader.rb
 - lib/classifier/streaming/progress.rb
 - lib/classifier/tfidf.rb
+- lib/classifier/version.rb
+- sig/classifier.rbs
 - sig/vendor/fast_stemmer.rbs
 - sig/vendor/gsl.rbs
 - sig/vendor/json.rbs
 - sig/vendor/matrix.rbs
 - sig/vendor/mutex_m.rbs
+- sig/vendor/optparse.rbs
 - sig/vendor/streaming.rbs
 - test/test_helper.rb
 homepage: https://rubyclassifier.com