rsvm 0.0.1
Sign up to get free protection for your applications and to get access to all the features.
- data/.gitignore +6 -0
- data/Gemfile +6 -0
- data/README.md +89 -0
- data/Rakefile +41 -0
- data/ext/libsvm/COPYRIGHT +31 -0
- data/ext/libsvm/FAQ.html +1837 -0
- data/ext/libsvm/README +748 -0
- data/ext/libsvm/extconf.rb +3 -0
- data/ext/libsvm/svm.cpp +3213 -0
- data/ext/libsvm/svm.h +102 -0
- data/lib/svm.rb +119 -0
- data/lib/svm/cross_validation.rb +39 -0
- data/lib/svm/debug.rb +12 -0
- data/lib/svm/model.rb +68 -0
- data/lib/svm/options.rb +69 -0
- data/lib/svm/problem.rb +151 -0
- data/lib/svm/scaler.rb +82 -0
- data/lib/svm/version.rb +3 -0
- data/rsvm.gemspec +25 -0
- data/test/fixtures/heart_scale.csv +270 -0
- data/test/fixtures/unbalanced.csv +164 -0
- data/test/lib/cross_validation_test.rb +35 -0
- data/test/lib/model_test.rb +84 -0
- data/test/lib/problem_test.rb +39 -0
- data/test/lib/scaler_test.rb +57 -0
- data/test/test_helper.rb +3 -0
- metadata +101 -0
data/ext/libsvm/svm.h
ADDED
@@ -0,0 +1,102 @@
|
|
1
|
+
#ifndef _LIBSVM_H
|
2
|
+
#define _LIBSVM_H
|
3
|
+
|
4
|
+
#define LIBSVM_VERSION 312
|
5
|
+
|
6
|
+
#ifdef __cplusplus
|
7
|
+
extern "C" {
|
8
|
+
#endif
|
9
|
+
|
10
|
+
extern int libsvm_version;
|
11
|
+
|
12
|
+
struct svm_node
|
13
|
+
{
|
14
|
+
int index;
|
15
|
+
double value;
|
16
|
+
};
|
17
|
+
|
18
|
+
struct svm_problem
|
19
|
+
{
|
20
|
+
int l;
|
21
|
+
double *y;
|
22
|
+
struct svm_node **x;
|
23
|
+
double *W; /* instance weight */
|
24
|
+
};
|
25
|
+
|
26
|
+
enum { C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR }; /* svm_type */
|
27
|
+
enum { LINEAR, POLY, RBF, SIGMOID, PRECOMPUTED }; /* kernel_type */
|
28
|
+
|
29
|
+
struct svm_parameter
|
30
|
+
{
|
31
|
+
int svm_type;
|
32
|
+
int kernel_type;
|
33
|
+
int degree; /* for poly */
|
34
|
+
double gamma; /* for poly/rbf/sigmoid */
|
35
|
+
double coef0; /* for poly/sigmoid */
|
36
|
+
|
37
|
+
/* these are for training only */
|
38
|
+
double cache_size; /* in MB */
|
39
|
+
double eps; /* stopping criteria */
|
40
|
+
double C; /* for C_SVC, EPSILON_SVR and NU_SVR */
|
41
|
+
int nr_weight; /* for C_SVC */
|
42
|
+
int *weight_label; /* for C_SVC */
|
43
|
+
double* weight; /* for C_SVC */
|
44
|
+
double nu; /* for NU_SVC, ONE_CLASS, and NU_SVR */
|
45
|
+
double p; /* for EPSILON_SVR */
|
46
|
+
int shrinking; /* use the shrinking heuristics */
|
47
|
+
int probability; /* do probability estimates */
|
48
|
+
};
|
49
|
+
|
50
|
+
//
|
51
|
+
// svm_model
|
52
|
+
//
|
53
|
+
struct svm_model
|
54
|
+
{
|
55
|
+
struct svm_parameter param; /* parameter */
|
56
|
+
int nr_class; /* number of classes, = 2 in regression/one class svm */
|
57
|
+
int l; /* total #SV */
|
58
|
+
struct svm_node **SV; /* SVs (SV[l]) */
|
59
|
+
double **sv_coef; /* coefficients for SVs in decision functions (sv_coef[k-1][l]) */
|
60
|
+
double *rho; /* constants in decision functions (rho[k*(k-1)/2]) */
|
61
|
+
double *probA; /* pairwise probability information */
|
62
|
+
double *probB;
|
63
|
+
|
64
|
+
/* for classification only */
|
65
|
+
|
66
|
+
int *label; /* label of each class (label[k]) */
|
67
|
+
int *nSV; /* number of SVs for each class (nSV[k]) */
|
68
|
+
/* nSV[0] + nSV[1] + ... + nSV[k-1] = l */
|
69
|
+
/* XXX */
|
70
|
+
int free_sv; /* 1 if svm_model is created by svm_load_model*/
|
71
|
+
/* 0 if svm_model is created by svm_train */
|
72
|
+
};
|
73
|
+
|
74
|
+
struct svm_model *svm_train(const struct svm_problem *prob, const struct svm_parameter *param);
|
75
|
+
void svm_cross_validation(const struct svm_problem *prob, const struct svm_parameter *param, int nr_fold, double *target);
|
76
|
+
|
77
|
+
int svm_save_model(const char *model_file_name, const struct svm_model *model);
|
78
|
+
struct svm_model *svm_load_model(const char *model_file_name);
|
79
|
+
|
80
|
+
int svm_get_svm_type(const struct svm_model *model);
|
81
|
+
int svm_get_nr_class(const struct svm_model *model);
|
82
|
+
void svm_get_labels(const struct svm_model *model, int *label);
|
83
|
+
double svm_get_svr_probability(const struct svm_model *model);
|
84
|
+
|
85
|
+
double svm_predict_values(const struct svm_model *model, const struct svm_node *x, double* dec_values);
|
86
|
+
double svm_predict(const struct svm_model *model, const struct svm_node *x);
|
87
|
+
double svm_predict_probability(const struct svm_model *model, const struct svm_node *x, double* prob_estimates);
|
88
|
+
|
89
|
+
void svm_free_model_content(struct svm_model *model_ptr);
|
90
|
+
void svm_free_and_destroy_model(struct svm_model **model_ptr_ptr);
|
91
|
+
void svm_destroy_param(struct svm_parameter *param);
|
92
|
+
|
93
|
+
const char *svm_check_parameter(const struct svm_problem *prob, const struct svm_parameter *param);
|
94
|
+
int svm_check_probability_model(const struct svm_model *model);
|
95
|
+
|
96
|
+
void svm_set_print_string_function(void (*print_func)(const char *));
|
97
|
+
|
98
|
+
#ifdef __cplusplus
|
99
|
+
}
|
100
|
+
#endif
|
101
|
+
|
102
|
+
#endif /* _LIBSVM_H */
|
data/lib/svm.rb
ADDED
@@ -0,0 +1,119 @@
|
|
1
|
+
require "svm/version"
|
2
|
+
require 'ffi'
|
3
|
+
|
4
|
+
require_relative 'svm/debug'
|
5
|
+
|
6
|
+
module Svm
|
7
|
+
extend FFI::Library
|
8
|
+
extend Svm::Debug
|
9
|
+
|
10
|
+
ffi_lib File.join(File.dirname(__FILE__), "libsvm/libsvm.#{RbConfig::CONFIG['DLEXT']}")
|
11
|
+
|
12
|
+
enum :svm_type, [:c_svc, :nu_svc, :one_class, :epsilon_svr, :nu_svr]
|
13
|
+
enum :kernel_type, [:linear, :poly, :rbf, :sigmoid, :precomputed]
|
14
|
+
|
15
|
+
# FFI mirror of libsvm's `struct svm_node`: one (index, value) entry of a
# feature vector.
class NodeStruct < FFI::Struct
  layout :index, :int,
         :value, :double

  # Copies the feature values in +sample_xs+ into a freshly allocated native
  # array of nodes and returns the FFI::MemoryPointer that holds it.
  #
  # Feature j is stored with index j; one extra node with index -1 is
  # appended as the terminator (see the libsvm README).
  def self.node_array_from(sample_xs)
    feature_count = sample_xs.size
    buffer = FFI::MemoryPointer.new(NodeStruct, feature_count + 1)

    # Writes one node into the given slot of the native array.
    write_node = lambda do |slot, index, value|
      entry = NodeStruct.new(buffer + slot * NodeStruct.size)
      entry[:index] = index
      entry[:value] = value
    end

    feature_count.times { |j| write_node.call(j, j, sample_xs[j].to_f) }
    write_node.call(feature_count, -1, 0)

    buffer
  end
end
|
38
|
+
|
39
|
+
class ProblemStruct < FFI::Struct
|
40
|
+
layout :l, :int,
|
41
|
+
:y, :pointer,
|
42
|
+
:svm_node, :pointer,
|
43
|
+
:W, :pointer
|
44
|
+
end
|
45
|
+
|
46
|
+
class ParameterStruct < FFI::Struct
|
47
|
+
layout :svm_type, :svm_type,
|
48
|
+
:kernel_type, :kernel_type,
|
49
|
+
:degree, :int,
|
50
|
+
:gamma, :double,
|
51
|
+
:coef0, :double,
|
52
|
+
:cache_size, :double,
|
53
|
+
:eps, :double,
|
54
|
+
:c, :double,
|
55
|
+
:nr_weight, :int,
|
56
|
+
:weight_label, :pointer,
|
57
|
+
:weight, :pointer,
|
58
|
+
:nu, :double,
|
59
|
+
:p, :double,
|
60
|
+
:shrinking, :int,
|
61
|
+
:probability, :int
|
62
|
+
end
|
63
|
+
|
64
|
+
# FFI mirror of libsvm's `struct svm_model`. The field order follows the C
# declaration; `:svm_node` is the slot the header calls `SV`.
#
# Being an FFI::ManagedStruct, `release` is invoked with the native pointer
# when the Ruby wrapper is garbage collected.
class ModelStruct < FFI::ManagedStruct
  layout :param, ParameterStruct,
         :nr_class, :int,
         :l, :int,
         :svm_node, :pointer,
         :sv_coef, :pointer,
         :rho, :pointer,
         :probA, :pointer,
         :probB, :pointer,
         :label, :pointer,
         :nSV, :pointer,
         :free_sv, :int

  # Frees the native model behind +ptr+.
  #
  # svm_free_model_content alone only releases the arrays inside the model
  # and leaves the svm_model allocation itself behind, so the struct leaked
  # on every collected model. svm_free_and_destroy_model releases both, but
  # takes a `struct svm_model **`, hence the one-slot holder.
  def self.release(ptr)
    holder = FFI::MemoryPointer.new(:pointer)
    holder.write_pointer(ptr)
    Svm.svm_free_and_destroy_model(holder)
  end
end
|
81
|
+
|
82
|
+
attach_function 'svm_train', [:pointer, :pointer], :pointer
|
83
|
+
|
84
|
+
attach_function 'svm_cross_validation', [:pointer, :pointer, :int, :pointer], :void
|
85
|
+
attach_function 'svm_save_model', [:string, :pointer], :int
|
86
|
+
attach_function 'svm_load_model', [:string], :pointer
|
87
|
+
attach_function 'svm_get_svm_type', [:pointer], :int
|
88
|
+
attach_function 'svm_get_nr_class', [ :pointer], :int
|
89
|
+
attach_function 'svm_get_labels', [:pointer, :pointer], :void
|
90
|
+
attach_function 'svm_get_svr_probability', [:pointer], :double
|
91
|
+
|
92
|
+
attach_function 'svm_predict_values', [:pointer, :pointer, :pointer], :double
|
93
|
+
attach_function 'svm_predict', [:pointer, :pointer], :double
|
94
|
+
attach_function 'svm_predict_probability', [:pointer, :pointer, :pointer], :double
|
95
|
+
|
96
|
+
attach_function 'svm_free_model_content', [:pointer], :void
|
97
|
+
attach_function 'svm_free_and_destroy_model', [:pointer], :void
|
98
|
+
attach_function 'svm_destroy_param', [:pointer], :void
|
99
|
+
|
100
|
+
attach_function 'svm_check_parameter', [:pointer, :pointer ], :string
|
101
|
+
attach_function 'svm_check_probability_model', [:pointer,], :int
|
102
|
+
attach_function 'svm_set_print_string_function', [:pointer,], :void
|
103
|
+
|
104
|
+
|
105
|
+
DebugCallback = FFI::Function.new(:void, [:string]) do |message|
|
106
|
+
print message if Svm.debug
|
107
|
+
end
|
108
|
+
|
109
|
+
Svm.svm_set_print_string_function(DebugCallback)
|
110
|
+
Svm.debug = false
|
111
|
+
end
|
112
|
+
|
113
|
+
|
114
|
+
|
115
|
+
require_relative 'svm/cross_validation'
|
116
|
+
require_relative 'svm/options'
|
117
|
+
require_relative 'svm/problem'
|
118
|
+
require_relative 'svm/model'
|
119
|
+
require_relative 'svm/scaler'
|
@@ -0,0 +1,39 @@
|
|
1
|
+
module Svm
|
2
|
+
# Mixin providing k-fold cross validation and a C/gamma grid search.
#
# The including class is expected to supply: num_samples, value(i),
# weight_for(i), set(options), options and problem_struct.
module CrossValidation

  # Runs a cross validation and returns the summed weight of the samples
  # whose predicted value equals their actual value.
  #
  # n_folds        - number of folds.
  # custom_options - optional hash of options applied before validating.
  #
  # Returns 0 when the problem has no samples (previously nil).
  def results_for_cross_validation(n_folds = 5, custom_options = nil)
    predicted = cross_validate(n_folds, custom_options)

    num_samples.times.inject(0) do |score, i|
      value(i) == predicted[i] ? score + weight_for(i) : score
    end
  end

  # Calls libsvm's svm_cross_validation and returns the predicted value of
  # every sample as an array of doubles.
  #
  # more_options - optional hash passed to `set` first; note that this
  #                changes the options of the receiver.
  def cross_validate(n_folds = 5, more_options = nil)
    set(more_options) if more_options

    predicted_results_pointer = FFI::MemoryPointer.new(:double, num_samples)

    Svm.svm_cross_validation(problem_struct, options.parameter_struct, n_folds, predicted_results_pointer)

    predicted_results_pointer.read_array_of_double(num_samples)
  end

  # Grid search over C = 2^-1..2^14 and gamma = 2^-13..2^-1, scoring each
  # combination with results_for_cross_validation.
  #
  # Returns a hash {:c => Float, :gamma => Float} with the best scoring
  # pair. Powers are computed with a Float base so the values are always
  # Floats (an Integer base yields Rationals for negative exponents).
  def find_best_parameters(n_folds = 5)
    c_exponents = (-1..14).to_a
    gamma_exponents = (-13..-1).to_a

    best_c_exp, best_gamma_exp = c_exponents.product(gamma_exponents).max_by do |c_exp, gamma_exp|
      results_for_cross_validation(n_folds, :c => 2.0**c_exp, :gamma => 2.0**gamma_exp)
    end

    { :c => 2.0**best_c_exp, :gamma => 2.0**best_gamma_exp }
  end
end
|
39
|
+
end
|
data/lib/svm/debug.rb
ADDED
data/lib/svm/model.rb
ADDED
@@ -0,0 +1,68 @@
|
|
1
|
+
module Svm
|
2
|
+
class ModelSerializationError < StandardError; end
|
3
|
+
class ModelError < StandardError; end
|
4
|
+
|
5
|
+
# Wrapper around a trained (or loaded) libsvm model.
#
# model_struct - the ModelStruct holding the native model.
# scaler       - optional object applied to samples before prediction.
class Model
  attr_reader :model_struct
  attr_accessor :scaler

  def initialize(model_struct)
    @model_struct = model_struct
  end

  # Writes the model to +path+ using svm_save_model.
  #
  # Raises ModelSerializationError when libsvm reports a non-zero result.
  def save(path)
    result = Svm.svm_save_model(path, model_struct.pointer)
    raise ModelSerializationError.new("Unable to save model to file. Error: #{result}") unless result == 0
  end

  # Reads a model from +path+ using svm_load_model and wraps it.
  #
  # Raises ModelSerializationError when libsvm returns a NULL pointer. The
  # message used to interpolate an undefined local (`result`), which turned
  # every failed load into a NameError; it now reports the path instead.
  def self.load(path)
    model_struct_pointer = Svm.svm_load_model(path)

    if model_struct_pointer.null?
      raise ModelSerializationError.new("Unable to load model from file: #{path}")
    end

    new(ModelStruct.new(model_struct_pointer))
  end

  # Number of classes as reported by svm_get_nr_class.
  def number_of_classes
    Svm.svm_get_nr_class(model_struct)
  end

  # Class labels as an array of integers, read through svm_get_labels.
  def labels
    class_count = number_of_classes
    labels_array = FFI::MemoryPointer.new(:int, class_count)

    Svm.svm_get_labels(model_struct, labels_array)

    labels_array.read_array_of_int(class_count)
  end

  # Returns the value svm_predict yields for +sample+ (an array of feature
  # values).
  def predict(sample)
    # The return value is ignored, so this presumably scales +sample+ in
    # place - TODO confirm against Scaler#scale.
    scaler.scale(sample) if scaler

    nodes_ptr = NodeStruct.node_array_from(sample)
    Svm.svm_predict(model_struct, nodes_ptr)
  end

  # Returns a hash mapping each class label to the probability estimated by
  # svm_predict_probability for +sample+.
  #
  # Raises ModelError when svm_check_probability_model does not return 1.
  def predict_probabilities(sample)
    unless Svm.svm_check_probability_model(model_struct) == 1
      raise ModelError.new("Model doesn't have probability info")
    end

    # Return value ignored; presumably scales in place - TODO confirm.
    scaler.scale(sample) if scaler

    nodes_ptr = NodeStruct.node_array_from(sample)

    # Query the class count and the labels once instead of once per class.
    class_count = number_of_classes
    prob_array = FFI::MemoryPointer.new(:double, class_count)

    Svm.svm_predict_probability(model_struct, nodes_ptr, prob_array)
    probabilities = prob_array.read_array_of_double(class_count)

    Hash[labels.zip(probabilities)]
  end
end
|
68
|
+
end
|
data/lib/svm/options.rb
ADDED
@@ -0,0 +1,69 @@
|
|
1
|
+
module Svm
|
2
|
+
class ParameterError < StandardError; end
|
3
|
+
|
4
|
+
# Holds the training options, both as a plain hash (readable through #[])
# and as the native ParameterStruct handed to libsvm.
class Options
  attr_reader :parameter_struct

  DEFAULT_OPTIONS = {
    :svm_type => :c_svc,
    :kernel_type => :rbf,
    :degree => 3,
    :gamma => 0,
    :coef0 => 0,
    :nu => 0.5,
    :cache_size => 100.0,
    :c => 1,
    :eps => 0.001,
    :p => 0.1,
    :shrinking => 1,
    :probability => 0,
    :nr_weight => 0,
    :cross_validation => 0,
    :nr_fold => 0,
    :scale => true
  }

  # Starts from DEFAULT_OPTIONS and lets +user_options+ override them.
  def initialize(user_options = {})
    @parameter_struct = ParameterStruct.new
    add(DEFAULT_OPTIONS.merge(user_options))
  end

  # Records +more_options+ in the options hash and copies every entry that
  # names a ParameterStruct member into the native struct. Other keys are
  # only kept in the hash.
  def add(more_options)
    options_hash.merge!(more_options)

    more_options.each do |name, setting|
      next unless parameter_struct.members.include?(name)

      parameter_struct[name] = setting
    end
  end

  # Stores per-label weights: +weights+ maps a label to its weight. The
  # labels (converted with to_i) and the weights are written to native
  # arrays referenced from the parameter struct.
  def label_weights=(weights)
    @weights = weights

    label_count = weights.keys.length

    parameter_struct[:nr_weight] = label_count

    parameter_struct[:weight_label] = FFI::MemoryPointer.new(:int, label_count)
    parameter_struct[:weight] = FFI::MemoryPointer.new(:double, label_count)

    label_ids = weights.keys.map { |label| label.to_i }

    parameter_struct[:weight_label].write_array_of_int(label_ids)
    parameter_struct[:weight].write_array_of_double(weights.values)
  end

  # The label weights; a hash defaulting to 1.0 until some are assigned.
  def weights
    @weights ||= Hash.new(1.0)
  end

  # Looks up a single option by key.
  def [](option)
    options_hash[option]
  end

  private

  # Lazily created backing hash for all options.
  def options_hash
    @options_hash ||= {}
  end
end
|
69
|
+
end
|
data/lib/svm/problem.rb
ADDED
@@ -0,0 +1,151 @@
|
|
1
|
+
require 'csv'
|
2
|
+
|
3
|
+
module Svm
|
4
|
+
# A training data set together with the options used to train on it.
#
# Each sample is an array whose first element is the target value and whose
# remaining elements are the feature values.
class Problem
  include CrossValidation

  attr_reader :num_samples
  attr_reader :num_features
  attr_reader :options

  attr_accessor :scaler

  # Builds a problem from a CSV file; every field is converted with to_f.
  #
  # csv_path - path of the CSV file.
  # options  - option hash forwarded to the constructor.
  def self.load_from_csv(csv_path, options = {})
    rows = CSV.read(csv_path).collect do |row|
      row.collect { |field| field.to_f }
    end

    instance = new(options)
    instance.data = rows

    instance
  end

  def initialize(user_options = {})
    @nodes_pointers = []
    @options = Options.new(user_options)
  end

  # Loads +samples+ into the native problem struct.
  #
  # samples - non-empty array of samples (target value first).
  # weights - optional per-sample weights; samples without a weight use 1.0.
  #
  # Raises ArgumentError when +samples+ is empty.
  def data=(samples, weights = nil)
    raise ArgumentError, "samples must not be empty" if samples.empty?

    @num_samples = samples.size
    @num_features = samples.first.size - 1
    @sample_weights = weights if weights

    # Start a fresh list so #sample indexes the new data instead of rows
    # left over from an earlier assignment. A new array (rather than
    # clearing the old one) keeps buffers pinned by earlier models alive.
    @nodes_pointers = []

    if options[:scale]
      self.scaler = Scaler.scale(samples)
      scaler.release_data!
    end

    problem_struct[:l] = num_samples
    problem_struct[:svm_node] = FFI::MemoryPointer.new(FFI::Pointer, num_samples)
    problem_struct[:y] = FFI::MemoryPointer.new(FFI::Type::DOUBLE, num_samples)
    problem_struct[:W] = FFI::MemoryPointer.new(FFI::Type::DOUBLE, num_samples)

    # One native node array per sample, each with its own terminator node.
    num_samples.times do |i|
      sample = samples[i].collect(&:to_f)

      sample_value = sample.first
      sample_xs = sample.drop(1)

      problem_struct[:y].put_double(FFI::Type::DOUBLE.size * i, sample_value)
      # weight_for falls back to 1.0, matching how cross validation scores.
      problem_struct[:W].put_double(FFI::Type::DOUBLE.size * i, weight_for(i))

      nodes_ptr = NodeStruct.node_array_from(sample_xs)
      problem_struct[:svm_node].put_pointer(FFI::Pointer.size * i, nodes_ptr)

      # We have to keep a reference to the pointer so it is not garbage collected
      @nodes_pointers << nodes_ptr
    end
  end

  # Feature values of the sample at +index+, read back from native memory.
  def sample(index)
    sample_ptr = @nodes_pointers[index]

    num_features.times.collect do |j|
      NodeStruct.new(sample_ptr + NodeStruct.size * j)[:value]
    end
  end

  # Target value of the sample at +index+.
  def value(index)
    problem_struct[:y].get_double(FFI::Type::DOUBLE.size * index)
  end

  # Number of samples stored in the native problem struct.
  def length
    problem_struct[:l]
  end

  # Applies +more_options+, trains with svm_train and returns a Model that
  # shares this problem's scaler.
  def generate_model(more_options = {})
    set(more_options)

    model_pointer = Svm.svm_train(problem_struct.pointer, options.parameter_struct.pointer)
    model = Model.new(ModelStruct.new(model_pointer))
    model.scaler = scaler

    # A model produced by svm_train points into the problem's node arrays
    # (see the libsvm README), so pin those buffers on the model; otherwise
    # they could be collected while the model is still in use.
    model.instance_variable_set(:@training_nodes, @nodes_pointers)

    model
  end

  # Hash mapping each label (as an integer) to the fraction of samples that
  # carry it.
  def suggested_labels_weights
    labels.inject({}) do |hash, label|
      hash[label.to_i] = num_samples_for(label).to_f / num_samples
      hash
    end
  end

  # Number of samples whose target value equals +label+.
  def num_samples_for(label)
    num_samples.times.count { |i| value(i) == label }
  end

  # Distinct target values, in order of first appearance.
  def labels
    num_samples.times.collect { |i| value(i) }.uniq
  end

  # Sets the per-label weights and validates the resulting parameters.
  def label_weights=(weights)
    options.label_weights = weights
    check_parameters!
  end

  # Weight of sample +i+, defaulting to 1.0 when none is recorded.
  def weight_for(i)
    sample_weights[i] || 1.0
  end

  def sample_weights=(weights)
    @sample_weights = weights
  end

  # Per-sample weights; defaults to 1.0 for every sample.
  def sample_weights
    @sample_weights ||= Array.new(num_samples, 1.0)
  end

  # Enables (truthy) or disables (falsy) probability estimates.
  def estimate_probabilities=(option)
    options.parameter_struct[:probability] = option ? 1 : 0
  end

  # Merges +custom_options+ into the options and validates them.
  def set(custom_options)
    options.add(custom_options)
    check_parameters!
  end

  private

  # Lazily created native problem struct.
  def problem_struct
    @problem_struct ||= ProblemStruct.new
  end

  # Raises ParameterError when svm_check_parameter returns an error message.
  def check_parameters!
    error = Svm.svm_check_parameter(problem_struct, options.parameter_struct)
    raise ParameterError.new("The provided options are not valid: #{error}") if error
  end

end
|
151
|
+
end
|