ruby_linear 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,235 @@
1
+ #include <math.h>
2
+ #include <stdio.h>
3
+ #include <string.h>
4
+ #include <stdarg.h>
5
+ #include "tron.h"
6
+
7
#ifndef min
// Returns the smaller of x and y. When x < y does not hold (ties, or an
// unordered comparison such as NaN), y is returned.
template <class T> static inline T min(T x, T y)
{
    if (x < y)
        return x;
    return y;
}
#endif
10
+
11
#ifndef max
// Returns the larger of x and y. When x > y does not hold (ties, or an
// unordered comparison such as NaN), y is returned.
template <class T> static inline T max(T x, T y)
{
    if (x > y)
        return x;
    return y;
}
#endif
14
+
15
+ #ifdef __cplusplus
16
+ extern "C" {
17
+ #endif
18
+
19
+ extern double dnrm2_(int *, double *, int *);
20
+ extern double ddot_(int *, double *, int *, double *, int *);
21
+ extern int daxpy_(int *, double *, double *, int *, double *, int *);
22
+ extern int dscal_(int *, double *, double *, int *);
23
+
24
+ #ifdef __cplusplus
25
+ }
26
+ #endif
27
+
28
// Default sink for TRON's diagnostic messages: writes buf to stdout and
// flushes right away so progress lines appear while the solver is running.
static void default_print(const char *buf)
{
    FILE *const out = stdout;
    fputs(buf, out);
    fflush(out);
}
33
+
34
+ void TRON::info(const char *fmt,...)
35
+ {
36
+ char buf[BUFSIZ];
37
+ va_list ap;
38
+ va_start(ap,fmt);
39
+ vsprintf(buf,fmt,ap);
40
+ va_end(ap);
41
+ (*tron_print_string)(buf);
42
+ }
43
+
44
+ TRON::TRON(const function *fun_obj, double eps, int max_iter)
45
+ {
46
+ this->fun_obj=const_cast<function *>(fun_obj);
47
+ this->eps=eps;
48
+ this->max_iter=max_iter;
49
+ tron_print_string = default_print;
50
+ }
51
+
52
+ TRON::~TRON()
53
+ {
54
+ }
55
+
56
+ void TRON::tron(double *w) // Trust-region Newton loop. w must hold get_nr_variable() doubles; it is zeroed first (incoming values are ignored) and holds the last accepted iterate on return.
57
+ { // Stops when: gradient norm <= eps * initial gradient norm, max_iter steps were accepted, f < -1e32, or the reductions become degenerate/negligible. NOTE(review): comments below assume dnrm2_/ddot_/daxpy_ follow standard BLAS semantics (2-norm, dot product, y += a*x) — confirm against the linked library.
58
+ // Parameters for updating the iterates.
59
+ double eta0 = 1e-4, eta1 = 0.25, eta2 = 0.75; // thresholds comparing actred with prered; a step is accepted only if actred > eta0*prered
60
+
61
+ // Parameters for updating the trust region size delta.
62
+ double sigma1 = 0.25, sigma2 = 0.5, sigma3 = 4; // multiplicative factors used to shrink (sigma1, sigma2) or grow (sigma3) delta
63
+
64
+ int n = fun_obj->get_nr_variable();
65
+ int i, cg_iter;
66
+ double delta, snorm, one=1.0; // 'one' exists so its address can be passed to daxpy_ as the scale factor
67
+ double alpha, f, fnew, prered, actred, gs;
68
+ int search = 1, iter = 1, inc = 1; // inc: unit stride for the BLAS-style calls
69
+ double *s = new double[n]; // step produced by trcg
70
+ double *r = new double[n]; // CG residual produced by trcg
71
+ double *w_new = new double[n]; // trial point w + s
72
+ double *g = new double[n]; // gradient at w. NOTE(review): these raw new[] buffers leak if a fun_obj method throws; std::vector would be exception-safe
73
+
74
+ for (i=0; i<n; i++)
75
+ w[i] = 0; // the solver always starts from the origin
76
+
77
+ f = fun_obj->fun(w); // objective at the starting point
78
+ fun_obj->grad(w, g);
79
+ delta = dnrm2_(&n, g, &inc); // initial trust-region radius = ||g|| at the starting point
80
+ double gnorm1 = delta; // reference gradient norm for the relative stopping test
81
+ double gnorm = gnorm1;
82
+
83
+ if (gnorm <= eps*gnorm1) // gnorm == gnorm1 here, so this only fires when the initial gradient is zero (or eps >= 1)
84
+ search = 0;
85
+
86
+ iter = 1; // redundant: iter was already initialised to 1
87
+
88
+ while (iter <= max_iter && search) // iter advances only on accepted steps, so rejected steps do not consume max_iter
89
+ {
90
+ cg_iter = trcg(delta, g, s, r); // approximate step s kept within radius delta, plus its residual r
91
+
92
+ memcpy(w_new, w, sizeof(double)*n); // w_new = w
93
+ daxpy_(&n, &one, s, &inc, w_new, &inc); // w_new = w + s
94
+
95
+ gs = ddot_(&n, g, &inc, s, &inc); // g^T s
96
+ prered = -0.5*(gs-ddot_(&n, s, &inc, r, &inc)); // predicted reduction of the quadratic model; trcg leaves r = -g - H*s, so this equals -(g^T s + 0.5*s^T H s) (assuming Hv is linear)
97
+ fnew = fun_obj->fun(w_new); // objective at the trial point
98
+
99
+ // Compute the actual reduction.
100
+ actred = f - fnew;
101
+
102
+ // On the first iteration, adjust the initial step bound.
103
+ snorm = dnrm2_(&n, s, &inc);
104
+ if (iter == 1)
105
+ delta = min(delta, snorm);
106
+
107
+ // Compute prediction alpha*snorm of the step.
108
+ if (fnew - f - gs <= 0) // fnew - f - gs: the change in f not explained by the linear term g^T s
109
+ alpha = sigma3; // non-positive remainder: allow the largest growth factor
110
+ else
111
+ alpha = max(sigma1, -0.5*(gs/(fnew - f - gs))); // otherwise the model-suggested scale, floored at sigma1
112
+
113
+ // Update the trust region bound according to the ratio of actual to predicted reduction.
114
+ if (actred < eta0*prered) // poor agreement (step will be rejected): shrink delta to at most sigma2*delta
115
+ delta = min(max(alpha, sigma1)*snorm, sigma2*delta);
116
+ else if (actred < eta1*prered) // moderate agreement: new delta lies in [sigma1*delta, sigma2*delta]
117
+ delta = max(sigma1*delta, min(alpha*snorm, sigma2*delta));
118
+ else if (actred < eta2*prered) // good agreement: new delta lies in [sigma1*delta, sigma3*delta]
119
+ delta = max(sigma1*delta, min(alpha*snorm, sigma3*delta));
120
+ else
121
+ delta = max(delta, min(alpha*snorm, sigma3*delta)); // very good agreement: delta never shrinks and grows at most to sigma3*delta
122
+
123
+ info("iter %2d act %5.3e pre %5.3e delta %5.3e f %5.3e |g| %5.3e CG %3d\n", iter, actred, prered, delta, f, gnorm, cg_iter);
124
+
125
+ if (actred > eta0*prered) // accept the step
126
+ {
127
+ iter++; // counts accepted steps only
128
+ memcpy(w, w_new, sizeof(double)*n);
129
+ f = fnew;
130
+ fun_obj->grad(w, g);
131
+
132
+ gnorm = dnrm2_(&n, g, &inc);
133
+ if (gnorm <= eps*gnorm1) // converged: gradient norm reduced by the factor eps
134
+ break;
135
+ }
136
+ if (f < -1.0e+32) // objective appears to be unbounded below
137
+ {
138
+ info("warning: f < -1.0e+32\n");
139
+ break;
140
+ }
141
+ if (fabs(actred) <= 0 && prered <= 0) // fabs(actred) <= 0 holds only when actred == 0
142
+ {
143
+ info("warning: actred and prered <= 0\n");
144
+ break;
145
+ }
146
+ if (fabs(actred) <= 1.0e-12*fabs(f) && // both reductions negligible relative to |f|: no further progress is possible
147
+ fabs(prered) <= 1.0e-12*fabs(f))
148
+ {
149
+ info("warning: actred and prered too small\n");
150
+ break;
151
+ }
152
+ }
153
+
154
+ delete[] g; // release the scratch buffers allocated above
155
+ delete[] r;
156
+ delete[] w_new;
157
+ delete[] s;
158
+ }
159
+
160
+ int TRON::trcg(double delta, double *g, double *s, double *r) // Conjugate-gradient solve of H*s = -g restricted to ||s|| <= delta. On return s holds the step and r the residual -g - H*s; returns the number of CG iterations (one Hv product each).
161
+ { // NOTE(review): comments below assume dnrm2_/ddot_/daxpy_/dscal_ have standard BLAS semantics (2-norm, dot product, y += a*x, x *= a) and that Hv is linear in its argument — confirm.
162
+ int i, inc = 1; // inc: unit stride for the BLAS-style calls
163
+ int n = fun_obj->get_nr_variable();
164
+ double one = 1; // addressable 1.0 for daxpy_
165
+ double *d = new double[n]; // CG search direction
166
+ double *Hd = new double[n]; // H*d. NOTE(review): raw buffers leak if Hv throws
167
+ double rTr, rnewTrnew, alpha, beta, cgtol;
168
+
169
+ for (i=0; i<n; i++) // start from s = 0, so r = d = -g
170
+ {
171
+ s[i] = 0;
172
+ r[i] = -g[i];
173
+ d[i] = r[i];
174
+ }
175
+ cgtol = 0.1*dnrm2_(&n, g, &inc); // stop once ||r|| <= 0.1*||g||
176
+
177
+ int cg_iter = 0;
178
+ rTr = ddot_(&n, r, &inc, r, &inc); // r^T r
179
+ while (1) // exits via the tolerance test or by reaching the trust-region boundary
180
+ {
181
+ if (dnrm2_(&n, r, &inc) <= cgtol)
182
+ break;
183
+ cg_iter++; // one Hessian-vector product per iteration
184
+ fun_obj->Hv(d, Hd);
185
+
186
+ alpha = rTr/ddot_(&n, d, &inc, Hd, &inc); // CG step length r^T r / d^T H d. NOTE(review): no guard for d^T H d <= 0 (e.g. a non-positive-definite H)
187
+ daxpy_(&n, &alpha, d, &inc, s, &inc); // s += alpha*d
188
+ if (dnrm2_(&n, s, &inc) > delta) // step left the trust region: back up and stop on the boundary instead
189
+ {
190
+ info("cg reaches trust region boundary\n");
191
+ alpha = -alpha;
192
+ daxpy_(&n, &alpha, d, &inc, s, &inc); // undo the step: s -= alpha*d
193
+
194
+ double std = ddot_(&n, s, &inc, d, &inc); // s^T d (this local name hides namespace std in this scope)
195
+ double sts = ddot_(&n, s, &inc, s, &inc); // s^T s
196
+ double dtd = ddot_(&n, d, &inc, d, &inc); // d^T d
197
+ double dsq = delta*delta;
198
+ double rad = sqrt(std*std + dtd*(dsq-sts)); // square root of the discriminant of ||s + alpha*d||^2 = delta^2
199
+ if (std >= 0) // two algebraically equal forms of the positive root; pick the one that avoids subtracting nearly equal numbers
200
+ alpha = (dsq - sts)/(std + rad);
201
+ else
202
+ alpha = (rad - std)/dtd;
203
+ daxpy_(&n, &alpha, d, &inc, s, &inc); // s now lies on the boundary ||s|| = delta
204
+ alpha = -alpha;
205
+ daxpy_(&n, &alpha, Hd, &inc, r, &inc); // keep r consistent with the final s: r -= alpha*Hd
206
+ break;
207
+ }
208
+ alpha = -alpha;
209
+ daxpy_(&n, &alpha, Hd, &inc, r, &inc); // r -= alpha*Hd
210
+ rnewTrnew = ddot_(&n, r, &inc, r, &inc);
211
+ beta = rnewTrnew/rTr; // standard CG direction update: d = r + beta*d
212
+ dscal_(&n, &beta, d, &inc);
213
+ daxpy_(&n, &one, r, &inc, d, &inc);
214
+ rTr = rnewTrnew;
215
+ }
216
+
217
+ delete[] d;
218
+ delete[] Hd;
219
+
220
+ return(cg_iter);
221
+ }
222
+
223
+ double TRON::norm_inf(int n, double *x)
224
+ {
225
+ double dmax = fabs(x[0]);
226
+ for (int i=1; i<n; i++)
227
+ if (fabs(x[i]) >= dmax)
228
+ dmax = fabs(x[i]);
229
+ return(dmax);
230
+ }
231
+
232
+ void TRON::set_print_string(void (*print_string) (const char *buf))
233
+ {
234
+ tron_print_string = print_string;
235
+ }
@@ -0,0 +1,34 @@
1
+ #ifndef _TRON_H
2
+ #define _TRON_H
3
+
4
// Abstract objective consumed by TRON. Implementations provide the value,
// gradient and a matrix-vector product at a point, plus the problem size.
// Declaration order is kept unchanged so the vtable layout stays identical.
class function
{
public:
    // Objective value at w.
    virtual double fun(double *w) = 0;
    // Writes the gradient at w into g.
    virtual void grad(double *w, double *g) = 0;
    // Writes H*s into Hs. H is presumably the Hessian of fun at the point
    // last passed to grad() — confirm against the implementations.
    virtual void Hv(double *s, double *Hs) = 0;

    // Number of variables; TRON sizes every array it passes to the methods
    // above with this value.
    virtual int get_nr_variable() = 0;
    virtual ~function() {}
};
14
+
15
+ class TRON
16
+ {
17
+ public:
18
+ TRON(const function *fun_obj, double eps = 0.1, int max_iter = 1000);
19
+ ~TRON();
20
+
21
+ void tron(double *w);
22
+ void set_print_string(void (*i_print) (const char *buf));
23
+
24
+ private:
25
+ int trcg(double delta, double *g, double *s, double *r);
26
+ double norm_inf(int n, double *x);
27
+
28
+ double eps;
29
+ int max_iter;
30
+ function *fun_obj;
31
+ void info(const char *fmt,...);
32
+ void (*tron_print_string)(const char *buf);
33
+ };
34
+ #endif
@@ -0,0 +1,11 @@
1
+ require 'rubylinear_native'
2
+
3
module RubyLinear
  # Validates a training-options hash.
  # Raises ArgumentError when :solver is missing (nil or false), or when the
  # hash has any (truthy) key other than :c, :solver, :eps and :weights.
  # Returns nil when the options are acceptable.
  def self.validate_options(options)
    raise ArgumentError, "A solver must be specified" if !options[:solver]

    permitted = [:c, :solver, :eps, :weights]
    unknown_keys = options.keys - permitted
    return unless unknown_keys.any?

    raise ArgumentError, "Unknown options: #{unknown_keys.inspect}"
  end
end
@@ -0,0 +1,187 @@
1
+ solver_type L1R_L2LOSS_SVC
2
+ nr_class 3
3
+ label 3 1 2
4
+ nr_feature 180
5
+ bias 1
6
+ w
7
+ 0.02013340696642403 0 -0.1571378238848434
8
+ 0.03749897313697532 0.02851810309740256 -0.08792709217307064
9
+ 0.182416942819516 0 -0.1629008738421322
10
+ 0.057398256185904 0.08587655132804607 -0.4021248537619269
11
+ -0.08440945299874639 0.04594758113861868 -0.324144012763583
12
+ 0.09626666880140594 -0.2013231822238345 -0.2020521658482208
13
+ 0.1931040870407536 0.1165262941019092 -0.2730273782988501
14
+ 0.2422178248127555 -0.2029517136885474 -0.1256037328887049
15
+ 0.2009360994942547 -0.06501527346239024 -0.2667453041189386
16
+ 0.113473515734091 -0.193577423835483 0.06393497637850488
17
+ -0.03188129365255857 0.07396690276090065 0.2179216904440836
18
+ -0.01844811371553047 0.000278102150848893 0
19
+ -0.2454698956511579 -0.04958829737707757 0.1374524909562953
20
+ -0.06384461175139515 -0.04794749625832508 0.1138723074173848
21
+ 0.06969443961289237 0 0
22
+ 0.08880262416744673 -0.03113542002473836 0
23
+ -0.02851344153266811 -0.03531206247888351 -0.310937248807243
24
+ 0.02731374566840788 -0.03331756586945989 -0.5012655094782823
25
+ 0.1556135142443686 -0.09609374854054094 -0.6690036211565307
26
+ 0.1527294433256928 -0.008028166013091862 -0.5057082644098619
27
+ 0.08207790770816108 -0.2463967887622368 -0.5701601996804387
28
+ -0.1529619302675301 0.04669608112101571 0.2584083647416214
29
+ -0.1977867353838474 0.1829972219855616 -0.04768726591707273
30
+ -0.2746248950245511 0.1486830364321789 0.3707325306451163
31
+ 0.09388798649794029 0.1043069683077304 -0.5388665768026363
32
+ 0.05515945698270306 -0.1238653927405989 -0.3111098032829153
33
+ 0.07662989834782494 0.0388171502657031 -0.2517686499082599
34
+ -0.196786935984455 0.1971861761411328 -0.2883905782386429
35
+ 0.06101685531234439 -0.06941051499959086 -0.1922521608149494
36
+ -0.01492102549442489 0.01143269145177065 -0.1938150388113903
37
+ 0.26054165944147 0 -0.3257780207028652
38
+ -0.00445195373070739 0.1625437658797168 0
39
+ 0.09977188407985588 0 0
40
+ 0.06075110495956209 0.01830573505843548 -0.03233466577314481
41
+ -0.004487703651053164 0.1482905964284915 0.04389193936356364
42
+ 0.04774688085823956 0.07106009328313824 -0.2272227991056337
43
+ 0.1367414387992451 -0.01348186810660379 0
44
+ 0 -0.09237169334683144 0.01202860091500364
45
+ -0.2577389391584256 0.0978028398408147 -0.1763851090165509
46
+ 0.03856182074880538 0.1333175792308643 -0.7715862763898076
47
+ -0.1539634745522554 -0.08947860535715692 -0.08748353324592795
48
+ 0 -0.06173275127171273 -0.6097348174798535
49
+ 0.1653577134130393 0 0
50
+ 0.146746091429147 -0.3158165489465554 0
51
+ 0.3550473443177195 -0.1115951760516682 -0.08148111110458925
52
+ 0.1789908673332288 0.04151220403330815 -0.2891189953559521
53
+ 0.1522653022519809 -0.1221033506004267 -0.1215766836878486
54
+ 0.334089236345716 -0.2479555293552262 -0.1222823927276005
55
+ 0.1029586560145411 -0.1253292845490155 -0.4573448902473595
56
+ -0.07859600124208306 -0.06951991870270835 0.04392653737766998
57
+ -0.04286471010654966 0.3653652800921977 -0.183588563356614
58
+ 0.2135009436133286 0.4282545827707505 -0.3455831493077505
59
+ 0.0316542204551418 0.1652090138838632 0
60
+ -0.01570459428047917 0.1714841275310023 -0.0002802095031544671
61
+ 0.4337040291345869 0.320477003252402 -1.07357944196913
62
+ 0.1542332388384909 0.2087156110048696 -0.06402992251132193
63
+ 0.3834587317179143 0.1169469158037457 -0.4836691854691754
64
+ 0.4292669882875256 -0.06021619121917744 -0.381172435948233
65
+ 0.004969053803552526 0 -0.003378540323624682
66
+ 0.3956972386906457 -0.1885203603544463 0.1559390909056178
67
+ 0.4809354243183575 -0.1249144417600274 -0.9425904267366738
68
+ 0.03284147106905205 -0.001497009592913744 -0.06031881502153554
69
+ 0.4263586416764014 0.1911953739612768 -0.4580304555848236
70
+ 0.464133419756847 0.1765901166072316 -0.5151050687893771
71
+ 0.04010075676388356 0 -0.05818335118371839
72
+ 0.2611256543151467 0.2889781582046403 -0.2085867779050407
73
+ 0.2513126114840114 0.318323482958084 -0.641593873059063
74
+ 0.1301407772137307 -0.1599519224791592 -0.3040625792400884
75
+ 0.4505680144858807 -0.1890432051835667 -0.4188869916521779
76
+ -0.01271520069385772 0.5751936541161782 -0.176309045123005
77
+ -0.08462461022139214 0.1533259088162128 0.3016877324584371
78
+ 0.2018803378915403 0.3699242120371078 -0.3817579870982373
79
+ 0.2571876257067033 0.06620044532424467 -1.059001159643563
80
+ -0.01190325363202492 0 -0.06773566453512289
81
+ 0.4154211463578731 0.3348395481257427 -0.7547281808817335
82
+ 0.2946725731168933 0.1454968708646255 -0.144636089871921
83
+ 0.2300875495240723 0.06665626874436811 -0.1746906182228936
84
+ 0.4129284582222247 0.08597284364492885 -0.1564748596272012
85
+ -0.03595831212662103 -0.3695369148729552 0.4505873987906854
86
+ 0.2020464200969789 -0.1607936883611799 0.2314311410144516
87
+ -0.0491558130016129 -0.3389104267018502 0.1866662336221878
88
+ 0.5471465589934924 0.5296582695406969 -2.026459774169788
89
+ -0.2505304783551446 0.5958066283114062 -0.5734756173503769
90
+ 0.9840493451886423 0.4792491321250367 -2.485954637331248
91
+ -1.444483133478516 0.07043466023549751 2.566119479163269
92
+ -0.2458475289339492 0 0.6571131360913476
93
+ -0.1199456491923447 0 -1.962860699813601
94
+ -0.1539055712409081 0 -0.09005550572989464
95
+ 0.3480411308156267 -0.2846573298791336 -0.3721573102648668
96
+ -1.643364123229578 0.9837664785513035 1.745459047818944
97
+ -0.2018132138798785 -0.6538354247538409 0.4084205962820638
98
+ 0.2946284097177052 -0.2042867403729732 -0.282048439265067
99
+ -0.7887045778181385 1.746969997328089 -0.1929226858477643
100
+ 1.121751200136391 -2.758274169154813 0.310019454614074
101
+ 1.215990132706368 -2.254155876531467 0.3536986271046261
102
+ 1.108844344247859 -2.157466206811264 0.4745704259376802
103
+ -0.4545913884495049 0.8362952811213272 -0.5334703682990538
104
+ 0.4339251007034062 -0.7513804425306917 0
105
+ -0.04214494936175231 0.1211825083077565 -0.1906580492446797
106
+ -0.7330237562105617 0.6642003663060955 -0.2986336544689749
107
+ -0.1217409878890608 0.3706677210073575 0
108
+ 0.18821418210195 -0.248983750036255 -0.0310344114799878
109
+ 0.05579139596485778 -0.2509772755540042 -0.08164957395236468
110
+ 0.2677863920823881 -0.405095519383235 -0.02071196261014919
111
+ -0.9727051582303718 1.04337071603934 -0.7344741437323482
112
+ 0.4160010328289956 -0.6428705183529503 0
113
+ 0.1421694406043006 -0.5432756197993849 0.3319197231702026
114
+ 0.2726953701813231 -0.478355571468085 0.2339226866536993
115
+ -0.2074701334947183 -0.07962269188220865 0.114985622339286
116
+ -0.08188956374825782 0.01512394197727574 0
117
+ -0.01476913689349868 0 -0.5063451574135208
118
+ 0.1281103713346617 0 0.2114326304744909
119
+ -0.01050070869108232 -0.1504384723061875 0.2604631939608084
120
+ 0.1597118847423477 -0.03345733067327621 -0.05144429534797917
121
+ -0.1712232469600758 -0.2156259055815165 0.2217404882305374
122
+ -0.04513937850294583 0.2420017067739638 0.0878606828053934
123
+ -0.04572112615625421 0.08754207396457059 -0.07831219754769082
124
+ 0.3819315686659829 -0.1179169566075425 -0.2837029031723398
125
+ 0.09189226458474982 0.007543947666165944 -0.2035181488565481
126
+ 0.1381751189063063 -0.007048167088425604 -0.1363772756063209
127
+ 0.1173982248079364 -0.4586988775702685 -0.05538156243409938
128
+ -0.3000035040449555 -0.2712393316875591 0.2967338757301871
129
+ -0.2459417103559974 0 0.05571590781124285
130
+ 0.03964949069745201 -0.1215114603841601 0
131
+ -0.00948109194715873 0 0
132
+ -0.09113993917249501 0.1472639863732914 0.1621662815644304
133
+ 0.1600758907482166 0.06328945737434939 0
134
+ -0.003087812558166637 -0.03649994915444396 -0.01764352018858131
135
+ 0.09096182339759093 0.04450465759416147 0.006425587254144259
136
+ 0.04158442261067564 -0.03859179687416574 -0.1412377622416668
137
+ -0.01963439558637189 0.03324638138398089 -0.330865912208441
138
+ -0.008482118808782163 0.05752153296183453 0.03179735844490969
139
+ 0.2032923274631153 -0.3625910712721387 0.1202796389488106
140
+ 0.1130430435113964 -0.1630359208276418 0.06884176254877679
141
+ 0.1889619632559745 -0.3956197312203688 -0.0029597288113893
142
+ 0.04089260621328007 -0.06209616467222634 -0.2985812272854372
143
+ -0.2525135546464782 -0.03933756404081787 0.181680996618776
144
+ -0.1357365130172946 -0.06744161495516253 -0.02869870843005774
145
+ 0.1296626482585284 -0.2067548017625899 0.1184905588709334
146
+ 0.1553963422195503 -0.1571463384816846 -0.2264514698434542
147
+ 0.009902447360973487 -0.1257420497777256 0
148
+ -0.0458971860219014 0 -0.0502433970046061
149
+ 0 0.3093972294532752 0
150
+ -0.2204179799238118 0.1854723881343433 0.06641274667557519
151
+ 0.1211876107986347 -0.2049494601602448 0.2505947291390709
152
+ -0.1336549914142491 -0.2924916887913175 0.3495434142063651
153
+ -0.005844025237753386 -0.2039439266927528 0.1753396216412218
154
+ 0.1310894322103897 -0.2073009043462201 -0.08487806748269781
155
+ -0.06460762927233279 0 0.006841907245008644
156
+ -0.199378161077782 0.02031740691915459 0.0851329177967203
157
+ 0 -0.3979996617928457 0.3604926496311117
158
+ -0.06085068708856548 -0.1404061160039149 -0.4616463361060925
159
+ -0.1294541021505957 0.04633752245745432 0.005584804764055991
160
+ 0.2496608792373075 0.03502558280529297 -0.1757728139373698
161
+ 0.09351280488248936 0.102078816596968 -0.3209209533328811
162
+ 0.1831037211993164 0.08501136740961888 -0.2825698589850296
163
+ 0.1657967510758501 -0.4256414833394009 0.1105748622355889
164
+ 0.03847201960984974 -0.2206551577790529 0.04186567315108874
165
+ 0.07710791994437377 0 -0.1801099834206085
166
+ 0.03002236419097492 -0.4343699397244553 -0.5196292369575694
167
+ 0.02545471991064811 -0.2657167943538667 -0.1605867575059343
168
+ 0.1362207830781751 -0.2900333653889114 -0.1189333747342645
169
+ -0.14959546181405 0.01797429636035943 0.1521227647150173
170
+ -0.04570281653466469 0 -0.05820737348717202
171
+ -0.1348056623278165 0 0.05229772010317342
172
+ 0.3281233649791613 0 0
173
+ -0.02248038838204792 -0.120965915287167 0.386835874433984
174
+ -0.2039482359853235 0.2528833480381093 -0.06293340531708454
175
+ -0.07085329146928954 -0.259162966169562 0.1436965820865992
176
+ 0 -0.09343577104520842 -0.01353592369683937
177
+ -0.2250225823106939 -0.1328242549716434 0.07253595454785078
178
+ -0.3558321345092992 -0.07658887126878239 0.6765789760350654
179
+ -0.3352773654481199 0.05013381043045785 0.5582573473553918
180
+ -0.2203945406918789 -0.009204485943198212 0.4504666411769261
181
+ 0.1649848364397697 0.2778268696480816 -0.09627317170991107
182
+ 0.2332502137518346 0 -0.2238032636728753
183
+ 0.1370556997200548 0 0.006533270533938317
184
+ 0.04667430265022181 -0.4499785508716155 0.08367564177614155
185
+ 0.1527953737850639 -0.4229172255726032 0
186
+ -0.1307416480886588 -0.03141958170964223 0.05829352255374286
187
+ 0.1888037445113804 -2.236609027510631 -0.06026513284159574