classifier 2.1.0 → 2.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +70 -199
- data/exe/classifier +9 -0
- data/ext/classifier/classifier_ext.c +1 -0
- data/ext/classifier/incremental_svd.c +393 -0
- data/ext/classifier/linalg.h +8 -0
- data/lib/classifier/bayes.rb +177 -53
- data/lib/classifier/cli.rb +880 -0
- data/lib/classifier/errors.rb +3 -0
- data/lib/classifier/knn.rb +351 -0
- data/lib/classifier/logistic_regression.rb +593 -0
- data/lib/classifier/lsi/incremental_svd.rb +166 -0
- data/lib/classifier/lsi/summary.rb +25 -5
- data/lib/classifier/lsi.rb +365 -17
- data/lib/classifier/streaming/line_reader.rb +99 -0
- data/lib/classifier/streaming/progress.rb +96 -0
- data/lib/classifier/streaming.rb +122 -0
- data/lib/classifier/tfidf.rb +408 -0
- data/lib/classifier/version.rb +3 -0
- data/lib/classifier.rb +5 -0
- data/sig/classifier.rbs +3 -0
- data/sig/vendor/json.rbs +1 -0
- data/sig/vendor/matrix.rbs +25 -14
- data/sig/vendor/optparse.rbs +19 -0
- data/sig/vendor/streaming.rbs +14 -0
- metadata +39 -6
|
@@ -0,0 +1,880 @@
|
|
|
1
|
+
# rbs_inline: enabled
|
|
2
|
+
|
|
3
|
+
require 'json'
|
|
4
|
+
require 'optparse'
|
|
5
|
+
require 'net/http'
|
|
6
|
+
require 'uri'
|
|
7
|
+
require 'fileutils'
|
|
8
|
+
require 'classifier'
|
|
9
|
+
|
|
10
|
+
module Classifier
|
|
11
|
+
class CLI
|
|
12
|
+
# @rbs @args: Array[String]
|
|
13
|
+
# @rbs @stdin: String?
|
|
14
|
+
# @rbs @options: Hash[Symbol, untyped]
|
|
15
|
+
# @rbs @output: Array[String]
|
|
16
|
+
# @rbs @error: Array[String]
|
|
17
|
+
# @rbs @exit_code: Integer
|
|
18
|
+
# @rbs @parser: OptionParser
|
|
19
|
+
|
|
20
|
+
# Maps every accepted -m/--model argument to the internal classifier type.
# 'lr' and 'logistic_regression' both select logistic regression. Key order
# matters: parse_options echoes the keys (joined with ', ') in its
# "Unknown classifier model" error.
CLASSIFIER_TYPES = {
  'bayes' => :bayes,
  'lsi' => :lsi,
  'knn' => :knn,
  'lr' => :logistic_regression,
  'logistic_regression' => :logistic_regression
}.freeze

# GitHub "owner/repo" used by models/pull/-r when no @user/repo is given;
# overridable with the CLASSIFIER_REGISTRY environment variable.
DEFAULT_REGISTRY = ENV.fetch('CLASSIFIER_REGISTRY', 'cardmagic/classifier-models') #: String
# Root of the local model cache (downloaded models live under CACHE_DIR/models);
# overridable with the CLASSIFIER_CACHE environment variable.
CACHE_DIR = ENV.fetch('CLASSIFIER_CACHE', File.expand_path('~/.classifier')) #: String
|
|
30
|
+
|
|
31
|
+
# Builds a CLI around one argument vector.
#
# args  - command-line words; copied so parsing never mutates the caller's array
# stdin - optional pre-read standard-input text; when nil, input is read from $stdin
def initialize(args, stdin: nil)
  default_model = ENV.fetch('CLASSIFIER_MODEL', './classifier.json')
  default_type = ENV.fetch('CLASSIFIER_TYPE', 'bayes')

  @args = args.dup
  @stdin = stdin
  @options = {
    model: default_model, type: default_type,
    probabilities: false, quiet: false,
    count: 10, k: 5, weighted: false,
    learning_rate: nil, regularization: nil, max_iterations: nil,
    remote: nil, output_path: nil
  }
  @output = [] #: Array[String]
  @error = [] #: Array[String]
  @exit_code = 0
end
|
|
52
|
+
|
|
53
|
+
# Parses options, runs the selected command, and reports the outcome.
#
# Returns a Hash with :output and :error (collected lines joined with "\n") and
# :exit_code: 0 on success, 2 for any command-line usage error, 1 for any other
# StandardError.
#
# Every OptionParser::ParseError subclass counts as a usage error. This includes
# AmbiguousOption (e.g. "--l") and NeedlessArgument (e.g. "--weighted=x"), not
# only InvalidOption/MissingArgument/InvalidArgument.
def run
  begin
    parse_options
    execute_command
  rescue OptionParser::ParseError => e
    @error << "Error: #{e.message}"
    @exit_code = 2
  rescue StandardError => e
    @error << "Error: #{e.message}"
    @exit_code = 1
  end

  { output: @output.join("\n"), error: @error.join("\n"), exit_code: @exit_code }
end
|
|
66
|
+
|
|
67
|
+
private
|
|
68
|
+
|
|
69
|
+
# Parses @args in place with OptionParser. Recognized flags are stored in
# @options; positional words (the command and its arguments) stay in @args.
# --help and --version append their text to @output, reset the exit code to 0,
# and stop parsing at once via throw(:done).
# Raises OptionParser::InvalidArgument for an unknown -m/--model type.
def parse_options
  command_summaries = [
    ' train <category> [files...] Train a category from files or stdin',
    ' info Show model information',
    ' fit Fit the model (logistic regression)',
    ' search <query> Semantic search (LSI only)',
    ' related <item> Find related documents (LSI only)',
    ' models [registry] List models in registry',
    ' pull <model> Download model from registry',
    ' push <file> Contribute model to registry',
    ' <text> Classify text (default action)'
  ]

  # Shared tail of --version and --help: emit text and abort parsing.
  finish = lambda do |text|
    @output << text
    @exit_code = 0
    throw :done
  end

  @parser = OptionParser.new do |opts|
    opts.banner = 'Usage: classifier [options] [command] [arguments]'
    opts.separator ''
    opts.separator 'Commands:'
    command_summaries.each { |summary| opts.separator(summary) }
    opts.separator ''
    opts.separator 'Options:'

    opts.on('-f', '--file FILE', 'Model file (default: ./classifier.json)') { |path| @options[:model] = path }

    opts.on('-m', '--model TYPE', 'Classifier model: bayes, lsi, knn, lr (default: bayes)') do |type|
      known = CLASSIFIER_TYPES.key?(type)
      raise OptionParser::InvalidArgument, "Unknown classifier model: #{type}. Valid models: #{CLASSIFIER_TYPES.keys.join(', ')}" unless known

      @options[:type] = type
    end

    opts.on('-r', '--remote MODEL', 'Use remote model: name or @user/repo:name') { |spec| @options[:remote] = spec }
    opts.on('-o', '--output FILE', 'Output path for pull command') { |path| @options[:output_path] = path }
    opts.on('-p', 'Show probabilities') { @options[:probabilities] = true }
    opts.on('-n', '--count N', Integer, 'Number of results for search/related (default: 10)') { |count| @options[:count] = count }
    opts.on('-k', '--neighbors N', Integer, 'Number of neighbors for KNN (default: 5)') { |neighbors| @options[:k] = neighbors }
    opts.on('--weighted', 'Use distance-weighted voting for KNN') { @options[:weighted] = true }
    opts.on('--learning-rate N', Float, 'Learning rate for logistic regression (default: 0.1)') { |rate| @options[:learning_rate] = rate }
    opts.on('--regularization N', Float, 'L2 regularization for logistic regression (default: 0.01)') { |strength| @options[:regularization] = strength }
    opts.on('--max-iterations N', Integer, 'Max iterations for logistic regression (default: 100)') { |limit| @options[:max_iterations] = limit }
    opts.on('-q', 'Quiet mode') { @options[:quiet] = true }
    opts.on('--local', 'List locally cached models (for models command)') { @options[:local] = true }
    opts.on('-v', '--version', 'Show version') { finish.call(Classifier::VERSION) }
    opts.on('-h', '--help', 'Show help') { finish.call(opts.to_s) }
  end

  catch(:done) { @parser.parse!(@args) }
end
|
|
159
|
+
|
|
160
|
+
# Dispatches on the first positional word. Any word that is not a known command
# (or no word at all) is treated as text to classify. Does nothing when option
# parsing already produced output (--help/--version) or a non-zero exit code.
def execute_command
  return unless @exit_code.zero? && @output.empty?

  handler = {
    'train' => :command_train,
    'info' => :command_info,
    'fit' => :command_fit,
    'search' => :command_search,
    'related' => :command_related,
    'models' => :command_models,
    'pull' => :command_pull,
    'push' => :command_push
  }.fetch(@args.first, :command_classify)

  __send__(handler)
end
|
|
186
|
+
|
|
187
|
+
# Handles `classifier train <category> [files...]`.
#
# For an LSI model with file arguments, each file is added via
# train_lsi_from_files, so each file path becomes one LSI item. Otherwise the
# training text (files' contents or stdin) goes to train_classifier. The model
# is saved afterwards. Exit code 2 when the category or the training data is
# missing.
def command_train
  _verb, category = @args.shift(2)

  if category.nil?
    @error << 'Error: category required for train command'
    @exit_code = 2
    return
  end

  classifier = load_or_create_classifier

  if @args.any? && classifier.is_a?(LSI)
    train_lsi_from_files(classifier, category, @args)
  else
    text = read_training_input
    if text.empty?
      @error << 'Error: no training data provided'
      @exit_code = 2
      return
    end
    train_classifier(classifier, category, text)
  end

  save_classifier(classifier)
end
|
|
215
|
+
|
|
216
|
+
# Handles `classifier info`: prints a pretty JSON description of the model file.
# Exit code 1 when the model file does not exist.
def command_info
  model_path = @options[:model]
  unless File.exist?(model_path)
    @error << "Error: model not found at #{model_path}"
    @exit_code = 1
    return
  end

  @output << JSON.pretty_generate(build_model_info(load_classifier))
end

# Builds the info hash: file and type, then generic attributes, then
# per-classifier details (which may overwrite earlier keys such as :categories).
def build_model_info(classifier)
  { file: @options[:model], type: classifier_type_name(classifier) }.tap do |info|
    add_common_info(info, classifier)
    add_classifier_specific_info(info, classifier)
  end
end

# Copies each generic attribute the classifier exposes; absent ones are skipped.
def add_common_info(info, classifier)
  info[:categories] = classifier.categories.map(&:to_s) if classifier.respond_to?(:categories)
  %i[training_count vocab_size].each do |attribute|
    info[attribute] = classifier.public_send(attribute) if classifier.respond_to?(attribute)
  end
  info[:fitted] = classifier.fitted? if classifier.respond_to?(:fitted?)
end

# Adds details specific to Bayes, LSI or KNN models; other types add nothing.
def add_classifier_specific_info(info, classifier)
  if classifier.is_a?(Bayes)
    add_bayes_info(info, classifier)
  elsif classifier.is_a?(LSI)
    add_lsi_info(info, classifier)
  elsif classifier.is_a?(KNN)
    add_knn_info(info, classifier)
  end
end

# Per-category word statistics read from the Bayes model's internal @categories
# table (category symbol => word => count).
def add_bayes_info(info, classifier)
  word_table = classifier.instance_variable_get(:@categories)
  info[:category_stats] = classifier.categories.each_with_object({}) do |category, stats|
    counts = word_table[category.to_sym] || {}
    stats[category.to_s] = { unique_words: counts.size, total_words: counts.values.sum }
  end
end

# Document count, item list, and the distinct categories attached to the items.
def add_lsi_info(info, classifier)
  items = classifier.items
  info[:documents] = items.size
  info[:items] = items
  labels = items.map { |item| classifier.categories_for(item) }.flatten.uniq
  info[:categories] = labels.map(&:to_s) unless labels.empty?
end

# Document count and distinct categories read from the KNN model's internal @data.
def add_knn_info(info, classifier)
  entries = classifier.instance_variable_get(:@data) || []
  info[:documents] = entries.size
  labels = entries.map { |entry| entry[:category] }.uniq
  info[:categories] = labels.map(&:to_s) unless labels.empty?
end
|
|
271
|
+
|
|
272
|
+
# Handles `classifier fit`: for models that respond to #fit, calls it and saves
# the model. Other models are left untouched. The status message is suppressed
# by -q. Exit code 1 when the model file is missing.
def command_fit
  model_path = @options[:model]
  unless File.exist?(model_path)
    @error << "Error: model not found at #{model_path}"
    @exit_code = 1
    return
  end

  classifier = load_classifier
  if classifier.respond_to?(:fit)
    classifier.fit
    save_classifier(classifier)
    message = 'Model fitted successfully'
  else
    message = 'Model does not require fitting'
  end

  @output << message unless @options[:quiet]
end
|
|
290
|
+
|
|
291
|
+
# Handles `classifier search <query>` (LSI models only).
#
# The query is the remaining arguments joined with spaces, or stdin when no
# arguments are given. Prints up to --count matches as "item:score" lines,
# with the score formatted to two decimals (0 when the item has no score
# entry). Exit code 1 for a missing or non-LSI model; 2 for an empty query.
def command_search
  @args.shift # remove 'search'

  unless File.exist?(@options[:model])
    @error << "Error: model not found at #{@options[:model]}"
    @exit_code = 1
    return
  end

  classifier = load_classifier

  unless classifier.is_a?(LSI)
    # The model type is selected with -m/--model; there is no -t option.
    @error << 'Error: search requires LSI model (use -m lsi)'
    @exit_code = 1
    return
  end

  query = @args.join(' ')
  query = read_stdin_line if query.empty?

  if query.empty?
    @error << 'Error: search query required'
    @exit_code = 2
    return
  end

  results = classifier.search(query, @options[:count])
  return if results.empty?

  # Score every document against the query once, not once per result.
  norms = classifier.proximity_norms_for_content(query)
  results.each do |item|
    score = norms.find { |candidate, _| candidate == item }&.last || 0
    @output << "#{item}:#{format('%.2f', score)}"
  end
end
|
|
323
|
+
|
|
324
|
+
# Handles `classifier related <item>` (LSI models only).
#
# Prints up to --count items similar to an existing item as "item:score"
# lines, with the score formatted to two decimals (0 when the item has no
# score entry). Exit code 2 without an item argument; 1 for a missing or
# non-LSI model, or an item the model does not contain.
def command_related
  @args.shift # remove 'related'
  item = @args.shift

  unless item
    @error << 'Error: item required for related command'
    @exit_code = 2
    return
  end

  unless File.exist?(@options[:model])
    @error << "Error: model not found at #{@options[:model]}"
    @exit_code = 1
    return
  end

  classifier = load_classifier

  unless classifier.is_a?(LSI)
    # The model type is selected with -m/--model; there is no -t option.
    @error << 'Error: related requires LSI model (use -m lsi)'
    @exit_code = 1
    return
  end

  unless classifier.items.include?(item)
    @error << "Error: item not found in model: #{item}"
    @exit_code = 1
    return
  end

  results = classifier.find_related(item, @options[:count])
  return if results.empty?

  # Score every document against +item+ once, not once per result.
  scores = classifier.proximity_array_for_content(item)
  results.each do |related_item|
    score = scores.find { |candidate, _| candidate == related_item }&.last || 0
    @output << "#{related_item}:#{format('%.2f', score)}"
  end
end
|
|
361
|
+
|
|
362
|
+
# Handles `classifier models [@user/repo]`: lists locally cached models with
# --local, otherwise the models advertised by a registry.
def command_models
  @args.shift # remove 'models'
  @options[:local] ? list_local_models : list_remote_models
end

# Prints one line per entry of the registry's models.json ("models" object):
# name, description cut to 40 characters, type and size.
def list_remote_models
  registry = parse_registry(@args.shift) || DEFAULT_REGISTRY
  index = fetch_registry_index(registry)
  return unless @exit_code.zero?

  catalog = index['models']
  if catalog.empty?
    @output << 'No models found in registry'
    return
  end

  catalog.each do |name, meta|
    @output << format('%-20<name>s %<desc>s (%<type>s, %<size>s)',
                      name: name,
                      desc: (meta['description'] || '').slice(0, 40),
                      type: meta['type'] || 'unknown',
                      size: meta['size'] || 'unknown')
  end
end

# Prints the cached models: CACHE_DIR/models/*.json (default registry) first,
# then CACHE_DIR/models/@user/repo/*.json shown as "@user/repo:name".
def list_local_models
  models_dir = File.join(CACHE_DIR, 'models')
  models = [] #: Array[{name: String, registry: String?, path: String}]

  if Dir.exist?(models_dir)
    Dir.glob(File.join(models_dir, '*.json')).each do |path|
      models << { name: File.basename(path, '.json'), registry: nil, path: path }
    end
    Dir.glob(File.join(models_dir, '@*', '*', '*.json')).each do |path|
      repo_dir = File.dirname(path)
      owner = File.basename(File.dirname(repo_dir)).delete_prefix('@')
      models << { name: File.basename(path, '.json'), registry: "#{owner}/#{File.basename(repo_dir)}", path: path }
    end
  end

  if models.empty?
    @output << 'No local models found'
    return
  end

  models.each do |model|
    type = load_model_info(model[:path])['type'] || 'unknown'
    label = model[:registry] ? "@#{model[:registry]}:#{model[:name]}" : model[:name]
    @output << format('%-30<name>s (%<type>s, %<size>s)', name: label, type: type, size: human_size(File.size(model[:path])))
  end
end

# Parsed JSON of a cached model file, or {} when the file is not valid JSON.
def load_model_info(path)
  JSON.parse(File.read(path))
rescue JSON::ParserError
  {}
end

# Formats a byte count with one decimal in B, KB, MB or GB (1024-based);
# GB is the largest unit.
def human_size(bytes)
  units = %w[B KB MB GB]
  value = bytes.to_f
  unit = units.first
  units.drop(1).each do |bigger|
    break unless value >= 1024

    value /= 1024
    unit = bigger
  end
  format('%<size>.1f %<unit>s', size: value, unit: unit)
end
|
|
448
|
+
|
|
449
|
+
# Handles `classifier pull <spec>`. The spec is NAME (default registry),
# @user/repo:NAME, or @user/repo alone, which pulls every model in that
# registry. Exit code 2 when no spec is given.
def command_pull
  @args.shift # remove 'pull'
  model_spec = @args.shift

  unless model_spec
    @error << 'Error: model name required for pull command'
    @exit_code = 2
    return
  end

  registry, model_name = parse_model_spec(model_spec)
  registry ||= DEFAULT_REGISTRY

  if model_name.nil?
    pull_all_models(registry)
  else
    pull_single_model(registry, model_name)
  end
end

# Downloads one model listed in a registry index and writes it to disk.
#
# registry    - GitHub "owner/repo" holding models.json
# model_name  - key of the model in the index's "models" object
# index       - an already-fetched index, or nil to fetch it here; passing it
#               lets pull_all_models avoid re-downloading models.json per model
# output_path - destination file; defaults to -o/--output, then the cache path
#
# Exit code 1 (with an error line) when the model is not listed or the
# download fails.
def pull_single_model(registry, model_name, index = nil, output_path = nil)
  index ||= fetch_registry_index(registry)
  return if @exit_code != 0

  model_info = index['models'][model_name]
  unless model_info
    @error << "Error: model '#{model_name}' not found in registry #{registry}"
    @exit_code = 1
    return
  end

  file_path = model_info['file'] || "models/#{model_name}.json"
  output_path ||= @options[:output_path] || cache_path_for(registry, model_name)

  @output << "Downloading #{model_name} from #{registry}..." unless @options[:quiet]

  content = fetch_github_file(registry, file_path)
  return if @exit_code != 0

  FileUtils.mkdir_p(File.dirname(output_path))
  File.write(output_path, content)

  @output << "Saved to #{output_path}" unless @options[:quiet]
end

# Downloads every model the registry lists, stopping at the first failure.
# models.json is fetched once and shared by all downloads. When -o/--output is
# given it names a directory and each model goes to <dir>/<name>.json; a
# single file path would be overwritten by each model in turn.
def pull_all_models(registry)
  index = fetch_registry_index(registry)
  return if @exit_code != 0

  models = index['models']
  if models.empty?
    @output << 'No models found in registry'
    return
  end

  @output << "Downloading #{models.size} models from #{registry}..." unless @options[:quiet]

  output_dir = @options[:output_path]
  models.each_key do |model_name|
    # basename keeps registry-supplied names from escaping the output directory
    target = output_dir && File.join(output_dir, "#{File.basename(model_name.to_s)}.json")
    pull_single_model(registry, model_name, index, target)
    break if @exit_code != 0
  end
end
|
|
510
|
+
|
|
511
|
+
# Handles `classifier push`. Nothing is uploaded: this only prints instructions
# for contributing a model to the registry through a GitHub pull request.
def command_push
  @args.shift # remove 'push'

  @output.concat([
    'To contribute a model to the registry:',
    '',
    '1. Fork https://github.com/cardmagic/classifier-models',
    '2. Add your model to the models/ directory',
    '3. Update models.json with your model metadata',
    '4. Create a pull request',
    '',
    'Or use the GitHub CLI:',
    '',
    ' gh repo fork cardmagic/classifier-models --clone',
    ' cp ./classifier.json classifier-models/models/my-model.json',
    ' # Edit classifier-models/models.json to add your model',
    ' cd classifier-models && gh pr create'
  ])
end
|
|
528
|
+
|
|
529
|
+
# Default action: classifies the positional arguments (joined with spaces) or,
# when there are none, each non-blank line of stdin.
#
# With -r/--remote the model comes from the registry cache instead (see
# classify_with_remote). If there is no text, no model file, and either
# $stdin is a terminal or no stdin text was injected, a getting-started guide
# is printed. Otherwise a missing model file sets exit code 1.
def command_classify
  text = @args.join(' ')

  if @options[:remote]
    classify_with_remote(text)
    return
  end

  model_path = @options[:model]
  model_present = File.exist?(model_path)

  if !model_present && text.empty? && ($stdin.tty? || @stdin.nil?)
    show_getting_started
    return
  end

  unless model_present
    @error << "Error: model not found at #{model_path}"
    @exit_code = 1
    return
  end

  classify_text_or_stdin(load_classifier, text)
end

# Classifies with a registry model named NAME or @user/repo:NAME. The model is
# pulled into the cache first if it is not already there. While classifying,
# @options[:model] points at the cached file; the previous value is restored
# afterwards, even on error. Exit code 2 when the spec names no model.
def classify_with_remote(text)
  registry, model_name = parse_model_spec(@options[:remote])
  registry ||= DEFAULT_REGISTRY

  if model_name.nil?
    @error << 'Error: model name required for -r option'
    @exit_code = 2
    return
  end

  cached_path = cache_path_for(registry, model_name)
  unless File.exist?(cached_path)
    pull_single_model(registry, model_name)
    return unless @exit_code.zero?
  end

  previous_model = @options[:model]
  @options[:model] = cached_path
  begin
    classify_text_or_stdin(load_classifier, text)
  ensure
    @options[:model] = previous_model
  end
end

# Classifies +text+ when it is non-empty. Otherwise classifies each non-blank
# stdin line, or shows the model's usage guide when stdin is empty too.
def classify_text_or_stdin(classifier, text)
  return classify_and_output(classifier, text) unless text.empty?

  lines = read_stdin_lines
  return show_model_usage(classifier) if lines.empty?

  lines.each { |line| classify_and_output(classifier, line) }
end
|
|
595
|
+
|
|
596
|
+
# Prints a short guide for a loaded model: its file, type and categories, then
# example classify/train/info commands built around the first category (or the
# placeholder "category" when none are known).
#
# Categories come from #categories when the model has one. Models without it
# but with #items/#categories_for (the interface add_lsi_info relies on) fall
# back to the distinct categories of their items; any other model shows none.
# Previously such models raised NoMethodError here.
# @rbs (untyped) -> void
def show_model_usage(classifier)
  type = classifier_type_name(classifier)
  raw_categories =
    if classifier.respond_to?(:categories)
      classifier.categories
    elsif classifier.respond_to?(:items) && classifier.respond_to?(:categories_for)
      classifier.items.map { |item| classifier.categories_for(item) }.flatten.uniq
    else
      []
    end
  cats = raw_categories.map(&:to_s).map(&:downcase)
  first_cat = cats.first || 'category'

  @output << "Model: #{@options[:model]} (#{type})"
  @output << "Categories: #{cats.join(', ')}"
  @output << ''
  @output << 'Classify text:'
  @output << ''
  @output << " classifier 'text to classify'"
  @output << " echo 'text to classify' | classifier"
  @output << ''
  @output << 'Train more data:'
  @output << ''
  @output << " echo 'new example text' | classifier train #{first_cat}"
  @output << " classifier train #{first_cat} file1.txt file2.txt"
  @output << ''
  @output << 'Other commands:'
  @output << ''
  @output << ' classifier info Show model details (JSON)'
end
|
|
619
|
+
|
|
620
|
+
# Classifies one piece of text and appends the result to @output; blank text is
# skipped. With -p the line is space-separated "category:probability" pairs
# (lower-cased, two decimals). Otherwise it is the lower-cased best category.
# Raises StandardError for a logistic-regression model that has not been fitted.
def classify_and_output(classifier, text)
  return if text.strip.empty?

  if classifier.is_a?(LogisticRegression) && !classifier.fitted?
    raise StandardError, "Model not fitted. Run 'classifier fit' after training."
  end

  line =
    if @options[:probabilities]
      get_probabilities(classifier, text)
        .map { |category, probability| "#{category.downcase}:#{format('%.2f', probability)}" }
        .join(' ')
    else
      classifier.classify(text).downcase
    end
  @output << line
end
|
|
636
|
+
|
|
637
|
+
# Category => probability map for +text+, from the best source the classifier
# offers: #probabilities as-is, softmax-normalized #classifications, or else
# certainty (1.0) for the single #classify answer.
def get_probabilities(classifier, text)
  return classifier.probabilities(text) if classifier.respond_to?(:probabilities)
  return normalize_scores(classifier.classifications(text)) if classifier.respond_to?(:classifications)

  { classifier.classify(text) => 1.0 }
end

# Softmax over raw scores. The maximum is subtracted first so Math.exp cannot
# overflow. Returns a new hash with the same keys whose values sum to 1.0 (up
# to rounding); an empty input gives {}.
def normalize_scores(scores)
  peak = scores.values.max
  weights = scores.transform_values { |score| Math.exp(score - peak) }
  denominator = weights.values.sum.to_f
  weights.transform_values { |weight| (weight / denominator).to_f }
end
|
|
654
|
+
|
|
655
|
+
# Loads the model at -f/--file when that file exists; otherwise builds a fresh
# classifier of the -m/--model type.
def load_or_create_classifier
  File.exist?(@options[:model]) ? load_classifier : create_classifier
end

# Reads the model file and rebuilds the classifier named by its "type" field.
# Raises RuntimeError for an unrecognized type.
def load_classifier
  data = JSON.parse(File.read(@options[:model]))
  kind = data['type']

  klass =
    case kind
    when 'bayes' then Bayes
    when 'lsi' then LSI
    when 'knn' then KNN
    when 'logistic_regression' then LogisticRegression
    else raise "Unknown classifier type in model: #{kind}"
    end
  klass.from_json(data)
end

# New, untrained classifier for the -m/--model type (Bayes when unknown). KNN
# uses -k/--weighted; logistic regression receives only the tuning options the
# user actually set.
def create_classifier
  case CLASSIFIER_TYPES.fetch(@options[:type], :bayes)
  when :lsi
    LSI.new(auto_rebuild: true)
  when :knn
    KNN.new(k: @options[:k], weighted: @options[:weighted])
  when :logistic_regression
    tuning = @options.slice(:learning_rate, :regularization, :max_iterations).compact #: Hash[Symbol, untyped]
    LogisticRegression.new(**tuning)
  else
    Bayes.new
  end
end
|
|
700
|
+
|
|
701
|
+
# Adds +text+ to the model under +category+, one example per non-blank line
# (each line stripped). Bayes and logistic regression first get the category
# created when #categories does not already include it; LSI and KNN store the
# category as a symbol. Other model types are left untouched.
def train_classifier(classifier, category, text)
  examples = text.each_line.map(&:strip).reject(&:empty?)

  case classifier
  when Bayes, LogisticRegression
    # NOTE(review): compares the raw CLI name with #categories verbatim; if the
    # model normalizes names (e.g. capitalizes them) add_category runs again for
    # an existing category — confirm add_category does not reset its data.
    classifier.add_category(category) unless classifier.categories.include?(category)
    examples.each { |example| classifier.train(category, example) }
  when LSI
    examples.each { |example| classifier.add_item(example, category.to_sym) }
  when KNN
    examples.each { |example| classifier.add(category.to_sym => example) }
  end
end

# Adds each file to an LSI model as one item keyed by its path; the file is read
# up front and its text supplied through the add_item block.
def train_lsi_from_files(classifier, category, files)
  label = category.to_sym
  files.each do |path|
    body = File.read(path)
    classifier.add_item(path, label) { body }
  end
end

# Attaches file storage at -f/--file to the classifier and saves it there.
def save_classifier(classifier)
  target = Storage::File.new(path: @options[:model])
  classifier.storage = target
  classifier.save
end

# Type name as written in model files ('unknown' for anything else).
def classifier_type_name(classifier)
  if classifier.is_a?(Bayes) then 'bayes'
  elsif classifier.is_a?(LSI) then 'lsi'
  elsif classifier.is_a?(KNN) then 'knn'
  elsif classifier.is_a?(LogisticRegression) then 'logistic_regression'
  else 'unknown'
  end
end
|
|
742
|
+
|
|
743
|
+
# Training text: the contents of the files named in @args joined with "\n", or
# all of stdin when no files were given.
def read_training_input
  return read_stdin if @args.empty?

  @args.map { |file| File.read(file) }.join("\n")
end

# Standard input: the injected @stdin text when present. Otherwise $stdin is
# read, or '' is returned when $stdin is an interactive terminal.
def read_stdin
  return @stdin if @stdin
  return '' if $stdin.tty?

  $stdin.read
end

# All of stdin (despite the name) as one string with outer whitespace stripped.
def read_stdin_line
  read_stdin.to_s.strip
end

# Stdin split into stripped, non-blank lines.
def read_stdin_lines
  read_stdin.to_s.each_line.map(&:strip).reject(&:empty?)
end
|
|
762
|
+
|
|
763
|
+
# Prints the first-run guide shown when no model file exists yet: training and
# classification examples, LSI search, key options and registry commands.
# Sections are separated by exactly one blank line.
# @rbs () -> void
def show_getting_started
  sections = [
    ['Classifier - Text classification from the command line'],
    ['Get started by training some categories:'],
    [
      ' # Train from files',
      ' classifier train spam spam_emails/*.txt',
      ' classifier train ham good_emails/*.txt'
    ],
    [
      ' # Train from stdin',
      " echo 'buy viagra now free pills cheap meds' | classifier train spam",
      " echo 'meeting scheduled for tomorrow to discuss project' | classifier train ham"
    ],
    ['Then classify text:'],
    [" classifier 'free money buy now'", " classifier 'meeting postponed to friday'"],
    ['Use LSI for semantic search:'],
    [
      " echo 'ruby is a dynamic programming language' | classifier train docs -m lsi",
      " echo 'python is great for data science' | classifier train docs -m lsi",
      " classifier search 'programming'"
    ],
    [
      'Options:',
      ' -f FILE Model file (default: ./classifier.json)',
      ' -m TYPE Model type: bayes, lsi, knn, lr (default: bayes)',
      ' -r MODEL Use remote model from registry',
      ' -p Show probabilities'
    ],
    ['Use pre-trained models:'],
    [
      ' classifier models List available models',
      ' classifier pull sentiment Download a model',
      " classifier -r sentiment 'I love this!' Classify with remote model"
    ],
    ['Run "classifier --help" for full usage.']
  ]

  sections.each_with_index do |section, position|
    @output << '' unless position.zero?
    @output.concat(section)
  end
end
|
|
802
|
+
|
|
803
|
+
# Parses a registry argument of the form "@user/repo" and returns "user/repo".
# Returns nil when +arg+ is nil, lacks the "@" prefix, or has nothing after it,
# so callers fall back to DEFAULT_REGISTRY instead of using "".
# @rbs (String?) -> String?
def parse_registry(arg)
  return nil unless arg&.start_with?('@')

  registry = arg.delete_prefix('@')
  registry.empty? ? nil : registry
end

# Parses a model spec into [registry, model_name]:
#   "name"            -> [nil, "name"]          (default registry)
#   "@user/repo:name" -> ["user/repo", "name"]
#   "@user/repo"      -> ["user/repo", nil]     (every model in the registry)
# An empty registry part ("@:name", "@") becomes nil so callers apply
# DEFAULT_REGISTRY rather than building URLs and paths from "".
# @rbs (String) -> [String?, String?]
def parse_model_spec(spec)
  return [nil, spec] unless spec.start_with?('@')

  registry, model_name = spec.delete_prefix('@').split(':', 2)
  [registry.to_s.empty? ? nil : registry, model_name]
end

# Local cache path for a downloaded model:
#   DEFAULT_REGISTRY -> CACHE_DIR/models/<name>.json
#   other registries -> CACHE_DIR/models/@<user/repo>/<name>.json
# Raises ArgumentError for an empty name or one containing a ".." path segment.
# Names can come from a remote models.json, so they must not escape the cache.
# @rbs (String, String) -> String
def cache_path_for(registry, model_name)
  name = model_name.to_s
  if name.empty? || name.split(%r{[/\\]}).include?('..')
    raise ArgumentError, "invalid model name: #{model_name.inspect}"
  end

  base = File.join(CACHE_DIR, 'models')
  base = File.join(base, "@#{registry}") unless registry == DEFAULT_REGISTRY
  File.join(base, "#{name}.json")
end

# Downloads and parses the registry's models.json.
# On a failed download, invalid JSON, or a document without a "models" object,
# records an error, sets exit code 1, and returns { 'models' => {} } so callers
# can bail out on @exit_code.
# @rbs (String) -> Hash[String, untyped]
def fetch_registry_index(registry)
  content = fetch_github_file(registry, 'models.json')
  return { 'models' => {} } if @exit_code != 0

  index = JSON.parse(content)
  return index if index.is_a?(Hash) && index['models'].is_a?(Hash)

  @error << 'Error: invalid models.json in registry: expected a "models" object'
  @exit_code = 1
  { 'models' => {} }
rescue JSON::ParserError => e
  @error << "Error: invalid models.json in registry: #{e.message}"
  @exit_code = 1
  { 'models' => {} }
end

# Fetches +file_path+ from the registry's GitHub repository through
# raw.githubusercontent.com, trying the "main" branch and then "master".
# Returns the body, or '' after recording an error (exit code 1, last HTTP
# status) when both branches fail.
# @rbs (String, String) -> String
def fetch_github_file(registry, file_path)
  response = nil
  %w[main master].each do |branch|
    uri = URI.parse("https://raw.githubusercontent.com/#{registry}/#{branch}/#{file_path}")
    response = Net::HTTP.get_response(uri)
    break if response.is_a?(Net::HTTPSuccess)
  end

  unless response.is_a?(Net::HTTPSuccess)
    @error << "Error: failed to fetch #{file_path} from #{registry} (#{response.code})"
    @exit_code = 1
    return ''
  end

  response.body.to_s
end
|
|
879
|
+
end
|
|
880
|
+
end
|