alglib 0.1.1

Files changed (255)
  1. data/History.txt +7 -0
  2. data/Manifest.txt +253 -0
  3. data/README.txt +33 -0
  4. data/Rakefile +27 -0
  5. data/ext/Rakefile +24 -0
  6. data/ext/alglib.i +24 -0
  7. data/ext/alglib/Makefile +157 -0
  8. data/ext/alglib/airyf.cpp +372 -0
  9. data/ext/alglib/airyf.h +81 -0
  10. data/ext/alglib/alglib.cpp +8558 -0
  11. data/ext/alglib/alglib_util.cpp +19 -0
  12. data/ext/alglib/alglib_util.h +14 -0
  13. data/ext/alglib/ap.cpp +877 -0
  14. data/ext/alglib/ap.english.html +364 -0
  15. data/ext/alglib/ap.h +666 -0
  16. data/ext/alglib/ap.russian.html +442 -0
  17. data/ext/alglib/apvt.h +754 -0
  18. data/ext/alglib/bdss.cpp +1500 -0
  19. data/ext/alglib/bdss.h +251 -0
  20. data/ext/alglib/bdsvd.cpp +1339 -0
  21. data/ext/alglib/bdsvd.h +164 -0
  22. data/ext/alglib/bessel.cpp +1226 -0
  23. data/ext/alglib/bessel.h +331 -0
  24. data/ext/alglib/betaf.cpp +105 -0
  25. data/ext/alglib/betaf.h +74 -0
  26. data/ext/alglib/bidiagonal.cpp +1328 -0
  27. data/ext/alglib/bidiagonal.h +350 -0
  28. data/ext/alglib/binomialdistr.cpp +247 -0
  29. data/ext/alglib/binomialdistr.h +153 -0
  30. data/ext/alglib/blas.cpp +576 -0
  31. data/ext/alglib/blas.h +132 -0
  32. data/ext/alglib/cblas.cpp +226 -0
  33. data/ext/alglib/cblas.h +57 -0
  34. data/ext/alglib/cdet.cpp +138 -0
  35. data/ext/alglib/cdet.h +92 -0
  36. data/ext/alglib/chebyshev.cpp +216 -0
  37. data/ext/alglib/chebyshev.h +76 -0
  38. data/ext/alglib/chisquaredistr.cpp +157 -0
  39. data/ext/alglib/chisquaredistr.h +144 -0
  40. data/ext/alglib/cholesky.cpp +285 -0
  41. data/ext/alglib/cholesky.h +86 -0
  42. data/ext/alglib/cinverse.cpp +298 -0
  43. data/ext/alglib/cinverse.h +111 -0
  44. data/ext/alglib/clu.cpp +337 -0
  45. data/ext/alglib/clu.h +120 -0
  46. data/ext/alglib/correlation.cpp +280 -0
  47. data/ext/alglib/correlation.h +77 -0
  48. data/ext/alglib/correlationtests.cpp +726 -0
  49. data/ext/alglib/correlationtests.h +134 -0
  50. data/ext/alglib/crcond.cpp +826 -0
  51. data/ext/alglib/crcond.h +148 -0
  52. data/ext/alglib/creflections.cpp +310 -0
  53. data/ext/alglib/creflections.h +165 -0
  54. data/ext/alglib/csolve.cpp +312 -0
  55. data/ext/alglib/csolve.h +99 -0
  56. data/ext/alglib/ctrinverse.cpp +387 -0
  57. data/ext/alglib/ctrinverse.h +98 -0
  58. data/ext/alglib/ctrlinsolve.cpp +297 -0
  59. data/ext/alglib/ctrlinsolve.h +81 -0
  60. data/ext/alglib/dawson.cpp +234 -0
  61. data/ext/alglib/dawson.h +74 -0
  62. data/ext/alglib/descriptivestatistics.cpp +436 -0
  63. data/ext/alglib/descriptivestatistics.h +112 -0
  64. data/ext/alglib/det.cpp +140 -0
  65. data/ext/alglib/det.h +94 -0
  66. data/ext/alglib/dforest.cpp +1819 -0
  67. data/ext/alglib/dforest.h +316 -0
  68. data/ext/alglib/elliptic.cpp +497 -0
  69. data/ext/alglib/elliptic.h +217 -0
  70. data/ext/alglib/estnorm.cpp +429 -0
  71. data/ext/alglib/estnorm.h +107 -0
  72. data/ext/alglib/expintegrals.cpp +422 -0
  73. data/ext/alglib/expintegrals.h +108 -0
  74. data/ext/alglib/faq.english.html +258 -0
  75. data/ext/alglib/faq.russian.html +272 -0
  76. data/ext/alglib/fdistr.cpp +202 -0
  77. data/ext/alglib/fdistr.h +163 -0
  78. data/ext/alglib/fresnel.cpp +211 -0
  79. data/ext/alglib/fresnel.h +91 -0
  80. data/ext/alglib/gammaf.cpp +338 -0
  81. data/ext/alglib/gammaf.h +104 -0
  82. data/ext/alglib/gqgengauss.cpp +235 -0
  83. data/ext/alglib/gqgengauss.h +92 -0
  84. data/ext/alglib/gqgenhermite.cpp +268 -0
  85. data/ext/alglib/gqgenhermite.h +63 -0
  86. data/ext/alglib/gqgenjacobi.cpp +297 -0
  87. data/ext/alglib/gqgenjacobi.h +72 -0
  88. data/ext/alglib/gqgenlaguerre.cpp +265 -0
  89. data/ext/alglib/gqgenlaguerre.h +69 -0
  90. data/ext/alglib/gqgenlegendre.cpp +300 -0
  91. data/ext/alglib/gqgenlegendre.h +62 -0
  92. data/ext/alglib/gqgenlobatto.cpp +305 -0
  93. data/ext/alglib/gqgenlobatto.h +97 -0
  94. data/ext/alglib/gqgenradau.cpp +232 -0
  95. data/ext/alglib/gqgenradau.h +95 -0
  96. data/ext/alglib/hbisinv.cpp +480 -0
  97. data/ext/alglib/hbisinv.h +183 -0
  98. data/ext/alglib/hblas.cpp +228 -0
  99. data/ext/alglib/hblas.h +64 -0
  100. data/ext/alglib/hcholesky.cpp +339 -0
  101. data/ext/alglib/hcholesky.h +91 -0
  102. data/ext/alglib/hermite.cpp +114 -0
  103. data/ext/alglib/hermite.h +49 -0
  104. data/ext/alglib/hessenberg.cpp +370 -0
  105. data/ext/alglib/hessenberg.h +152 -0
  106. data/ext/alglib/hevd.cpp +247 -0
  107. data/ext/alglib/hevd.h +107 -0
  108. data/ext/alglib/hsschur.cpp +1316 -0
  109. data/ext/alglib/hsschur.h +108 -0
  110. data/ext/alglib/htridiagonal.cpp +734 -0
  111. data/ext/alglib/htridiagonal.h +180 -0
  112. data/ext/alglib/ialglib.cpp +6 -0
  113. data/ext/alglib/ialglib.h +9 -0
  114. data/ext/alglib/ibetaf.cpp +960 -0
  115. data/ext/alglib/ibetaf.h +125 -0
  116. data/ext/alglib/igammaf.cpp +430 -0
  117. data/ext/alglib/igammaf.h +157 -0
  118. data/ext/alglib/inv.cpp +274 -0
  119. data/ext/alglib/inv.h +115 -0
  120. data/ext/alglib/inverseupdate.cpp +480 -0
  121. data/ext/alglib/inverseupdate.h +185 -0
  122. data/ext/alglib/jacobianelliptic.cpp +164 -0
  123. data/ext/alglib/jacobianelliptic.h +94 -0
  124. data/ext/alglib/jarquebera.cpp +2271 -0
  125. data/ext/alglib/jarquebera.h +80 -0
  126. data/ext/alglib/kmeans.cpp +356 -0
  127. data/ext/alglib/kmeans.h +76 -0
  128. data/ext/alglib/laguerre.cpp +94 -0
  129. data/ext/alglib/laguerre.h +48 -0
  130. data/ext/alglib/lbfgs.cpp +1167 -0
  131. data/ext/alglib/lbfgs.h +218 -0
  132. data/ext/alglib/lda.cpp +434 -0
  133. data/ext/alglib/lda.h +133 -0
  134. data/ext/alglib/ldlt.cpp +1130 -0
  135. data/ext/alglib/ldlt.h +124 -0
  136. data/ext/alglib/leastsquares.cpp +1252 -0
  137. data/ext/alglib/leastsquares.h +290 -0
  138. data/ext/alglib/legendre.cpp +107 -0
  139. data/ext/alglib/legendre.h +49 -0
  140. data/ext/alglib/linreg.cpp +1185 -0
  141. data/ext/alglib/linreg.h +380 -0
  142. data/ext/alglib/logit.cpp +1523 -0
  143. data/ext/alglib/logit.h +333 -0
  144. data/ext/alglib/lq.cpp +399 -0
  145. data/ext/alglib/lq.h +160 -0
  146. data/ext/alglib/lu.cpp +462 -0
  147. data/ext/alglib/lu.h +119 -0
  148. data/ext/alglib/mannwhitneyu.cpp +4490 -0
  149. data/ext/alglib/mannwhitneyu.h +115 -0
  150. data/ext/alglib/minlm.cpp +918 -0
  151. data/ext/alglib/minlm.h +312 -0
  152. data/ext/alglib/mlpbase.cpp +3375 -0
  153. data/ext/alglib/mlpbase.h +589 -0
  154. data/ext/alglib/mlpe.cpp +1369 -0
  155. data/ext/alglib/mlpe.h +552 -0
  156. data/ext/alglib/mlptrain.cpp +1056 -0
  157. data/ext/alglib/mlptrain.h +283 -0
  158. data/ext/alglib/nearunityunit.cpp +91 -0
  159. data/ext/alglib/nearunityunit.h +17 -0
  160. data/ext/alglib/normaldistr.cpp +377 -0
  161. data/ext/alglib/normaldistr.h +175 -0
  162. data/ext/alglib/nsevd.cpp +1869 -0
  163. data/ext/alglib/nsevd.h +140 -0
  164. data/ext/alglib/pca.cpp +168 -0
  165. data/ext/alglib/pca.h +87 -0
  166. data/ext/alglib/poissondistr.cpp +143 -0
  167. data/ext/alglib/poissondistr.h +130 -0
  168. data/ext/alglib/polinterpolation.cpp +685 -0
  169. data/ext/alglib/polinterpolation.h +206 -0
  170. data/ext/alglib/psif.cpp +173 -0
  171. data/ext/alglib/psif.h +88 -0
  172. data/ext/alglib/qr.cpp +414 -0
  173. data/ext/alglib/qr.h +168 -0
  174. data/ext/alglib/ratinterpolation.cpp +134 -0
  175. data/ext/alglib/ratinterpolation.h +72 -0
  176. data/ext/alglib/rcond.cpp +705 -0
  177. data/ext/alglib/rcond.h +140 -0
  178. data/ext/alglib/reflections.cpp +504 -0
  179. data/ext/alglib/reflections.h +165 -0
  180. data/ext/alglib/rotations.cpp +473 -0
  181. data/ext/alglib/rotations.h +128 -0
  182. data/ext/alglib/rsolve.cpp +221 -0
  183. data/ext/alglib/rsolve.h +99 -0
  184. data/ext/alglib/sbisinv.cpp +217 -0
  185. data/ext/alglib/sbisinv.h +171 -0
  186. data/ext/alglib/sblas.cpp +185 -0
  187. data/ext/alglib/sblas.h +64 -0
  188. data/ext/alglib/schur.cpp +156 -0
  189. data/ext/alglib/schur.h +102 -0
  190. data/ext/alglib/sdet.cpp +193 -0
  191. data/ext/alglib/sdet.h +101 -0
  192. data/ext/alglib/sevd.cpp +116 -0
  193. data/ext/alglib/sevd.h +99 -0
  194. data/ext/alglib/sinverse.cpp +672 -0
  195. data/ext/alglib/sinverse.h +138 -0
  196. data/ext/alglib/spddet.cpp +138 -0
  197. data/ext/alglib/spddet.h +96 -0
  198. data/ext/alglib/spdgevd.cpp +842 -0
  199. data/ext/alglib/spdgevd.h +200 -0
  200. data/ext/alglib/spdinverse.cpp +509 -0
  201. data/ext/alglib/spdinverse.h +122 -0
  202. data/ext/alglib/spdrcond.cpp +421 -0
  203. data/ext/alglib/spdrcond.h +118 -0
  204. data/ext/alglib/spdsolve.cpp +275 -0
  205. data/ext/alglib/spdsolve.h +105 -0
  206. data/ext/alglib/spline2d.cpp +1192 -0
  207. data/ext/alglib/spline2d.h +301 -0
  208. data/ext/alglib/spline3.cpp +1264 -0
  209. data/ext/alglib/spline3.h +290 -0
  210. data/ext/alglib/srcond.cpp +595 -0
  211. data/ext/alglib/srcond.h +127 -0
  212. data/ext/alglib/ssolve.cpp +895 -0
  213. data/ext/alglib/ssolve.h +139 -0
  214. data/ext/alglib/stdafx.h +0 -0
  215. data/ext/alglib/stest.cpp +131 -0
  216. data/ext/alglib/stest.h +94 -0
  217. data/ext/alglib/studenttdistr.cpp +222 -0
  218. data/ext/alglib/studenttdistr.h +115 -0
  219. data/ext/alglib/studentttests.cpp +377 -0
  220. data/ext/alglib/studentttests.h +178 -0
  221. data/ext/alglib/svd.cpp +620 -0
  222. data/ext/alglib/svd.h +126 -0
  223. data/ext/alglib/tdbisinv.cpp +2608 -0
  224. data/ext/alglib/tdbisinv.h +228 -0
  225. data/ext/alglib/tdevd.cpp +1229 -0
  226. data/ext/alglib/tdevd.h +115 -0
  227. data/ext/alglib/tridiagonal.cpp +594 -0
  228. data/ext/alglib/tridiagonal.h +171 -0
  229. data/ext/alglib/trigintegrals.cpp +490 -0
  230. data/ext/alglib/trigintegrals.h +131 -0
  231. data/ext/alglib/trinverse.cpp +345 -0
  232. data/ext/alglib/trinverse.h +98 -0
  233. data/ext/alglib/trlinsolve.cpp +926 -0
  234. data/ext/alglib/trlinsolve.h +73 -0
  235. data/ext/alglib/tsort.cpp +405 -0
  236. data/ext/alglib/tsort.h +54 -0
  237. data/ext/alglib/variancetests.cpp +245 -0
  238. data/ext/alglib/variancetests.h +134 -0
  239. data/ext/alglib/wsr.cpp +6285 -0
  240. data/ext/alglib/wsr.h +96 -0
  241. data/ext/ap.i +97 -0
  242. data/ext/correlation.i +24 -0
  243. data/ext/extconf.rb +6 -0
  244. data/ext/logit.i +89 -0
  245. data/lib/alglib.rb +71 -0
  246. data/lib/alglib/correlation.rb +26 -0
  247. data/lib/alglib/linearregression.rb +63 -0
  248. data/lib/alglib/logit.rb +42 -0
  249. data/test/test_alglib.rb +52 -0
  250. data/test/test_correlation.rb +44 -0
  251. data/test/test_correlationtest.rb +45 -0
  252. data/test/test_linreg.rb +35 -0
  253. data/test/test_logit.rb +43 -0
  254. data/test/test_pca.rb +27 -0
  255. metadata +326 -0
@@ -0,0 +1,1056 @@
+ /*************************************************************************
+ Copyright (c) 2007-2008, Sergey Bochkanov (ALGLIB project).
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are
+ met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer listed
+ in this license in the documentation and/or other materials
+ provided with the distribution.
+
+ - Neither the name of the copyright holders nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *************************************************************************/
+
+ #include <stdafx.h>
+ #include "mlptrain.h"
+
+ static const double mindecay = 0.001;
+
+ static void mlpkfoldcvgeneral(const multilayerperceptron& n,
+ const ap::real_2d_array& xy,
+ int npoints,
+ double decay,
+ int restarts,
+ int foldscount,
+ bool lmalgorithm,
+ double wstep,
+ int maxits,
+ int& info,
+ mlpreport& rep,
+ mlpcvreport& cvrep);
+ static void mlpkfoldsplit(const ap::real_2d_array& xy,
+ int npoints,
+ int nclasses,
+ int foldscount,
+ bool stratifiedsplits,
+ ap::integer_1d_array& folds);
+
+ /*************************************************************************
+ Neural network training using modified Levenberg-Marquardt with exact
+ Hessian calculation and regularization. Subroutine trains neural network
+ with restarts from random positions. Algorithm is well suited for small
+ and medium scale problems (hundreds of weights).
+
+ INPUT PARAMETERS:
+ Network - neural network with initialized geometry
+ XY - training set
+ NPoints - training set size
+ Decay - weight decay constant, >=0.001
+ Decay term 'Decay*||Weights||^2' is added to error
+ function.
+ If you don't know what Decay to choose, use 0.001.
+ Restarts - number of restarts from random position, >0.
+ If you don't know what Restarts to choose, use 2.
+
+ OUTPUT PARAMETERS:
+ Network - trained neural network.
+ Info - return code:
+ * -9, if internal matrix inverse subroutine failed
+ * -2, if there is a point with class number
+ outside of [0..NOut-1].
+ * -1, if wrong parameters specified
+ (NPoints<=0, Restarts<1).
+ * 2, if task has been solved.
+ Rep - training report
+
+ -- ALGLIB --
+ Copyright 10.03.2009 by Bochkanov Sergey
+ *************************************************************************/
+ void mlptrainlm(multilayerperceptron& network,
+ const ap::real_2d_array& xy,
+ int npoints,
+ double decay,
+ int restarts,
+ int& info,
+ mlpreport& rep)
+ {
+ int nin;
+ int nout;
+ int wcount;
+ double lmftol;
+ double lmsteptol;
+ int i;
+ int j;
+ int k;
+ int mx;
+ double v;
+ double e;
+ double enew;
+ double xnorm2;
+ double stepnorm;
+ ap::real_1d_array g;
+ ap::real_1d_array d;
+ ap::real_2d_array h;
+ ap::real_2d_array hmod;
+ ap::real_2d_array z;
+ bool spd;
+ double nu;
+ double lambda;
+ double lambdaup;
+ double lambdadown;
+ int cvcnt;
+ double cvrelcnt;
+ lbfgsreport internalrep;
+ lbfgsstate state;
+ ap::real_1d_array x;
+ ap::real_1d_array y;
+ ap::real_1d_array wbase;
+ double wstep;
+ ap::real_1d_array wdir;
+ ap::real_1d_array wt;
+ ap::real_1d_array wx;
+ int pass;
+ ap::real_1d_array wbest;
+ double ebest;
+
+ mlpproperties(network, nin, nout, wcount);
+ lambdaup = 10;
+ lambdadown = 0.3;
+ lmftol = 0.001;
+ lmsteptol = 0.001;
+
+ //
+ // Test inputs
+ //
+ if( npoints<=0||restarts<1 )
+ {
+ info = -1;
+ return;
+ }
+ if( mlpissoftmax(network) )
+ {
+ for(i = 0; i <= npoints-1; i++)
+ {
+ if( ap::round(xy(i,nin))<0||ap::round(xy(i,nin))>=nout )
+ {
+ info = -2;
+ return;
+ }
+ }
+ }
+ decay = ap::maxreal(decay, mindecay);
+ info = 2;
+
+ //
+ // Initialize data
+ //
+ rep.ngrad = 0;
+ rep.nhess = 0;
+ rep.ncholesky = 0;
+
+ //
+ // General case.
+ // Prepare task and network. Allocate space.
+ //
+ mlpinitpreprocessor(network, xy, npoints);
+ g.setbounds(0, wcount-1);
+ h.setbounds(0, wcount-1, 0, wcount-1);
+ hmod.setbounds(0, wcount-1, 0, wcount-1);
+ wbase.setbounds(0, wcount-1);
+ wdir.setbounds(0, wcount-1);
+ wbest.setbounds(0, wcount-1);
+ wt.setbounds(0, wcount-1);
+ wx.setbounds(0, wcount-1);
+ ebest = ap::maxrealnumber;
+
+ //
+ // Multiple passes
+ //
+ for(pass = 1; pass <= restarts; pass++)
+ {
+
+ //
+ // Initialize weights
+ //
+ mlprandomize(network);
+
+ //
+ // First stage of the hybrid algorithm: LBFGS
+ //
+ ap::vmove(&wbase(0), &network.weights(0), ap::vlen(0,wcount-1));
+ minlbfgs(wcount, ap::minint(wcount, 5), wbase, 0.0, 0.0, 0.0, ap::maxint(25, wcount), 0, state);
+ while(minlbfgsiteration(state))
+ {
+
+ //
+ // gradient
+ //
+ ap::vmove(&network.weights(0), &state.x(0), ap::vlen(0,wcount-1));
+ mlpgradbatch(network, xy, npoints, state.f, state.g);
+
+ //
+ // weight decay
+ //
+ v = ap::vdotproduct(&network.weights(0), &network.weights(0), ap::vlen(0,wcount-1));
+ state.f = state.f+0.5*decay*v;
+ ap::vadd(&state.g(0), &network.weights(0), ap::vlen(0,wcount-1), decay);
+
+ //
+ // next iteration
+ //
+ rep.ngrad = rep.ngrad+1;
+ }
+ minlbfgsresults(state, wbase, internalrep);
+ ap::vmove(&network.weights(0), &wbase(0), ap::vlen(0,wcount-1));
+
+ //
+ // Second stage of the hybrid algorithm: LM
+ //
+ // Initialize H with identity matrix,
+ // G with gradient,
+ // E with regularized error.
+ //
+ mlphessianbatch(network, xy, npoints, e, g, h);
+ v = ap::vdotproduct(&network.weights(0), &network.weights(0), ap::vlen(0,wcount-1));
+ e = e+0.5*decay*v;
+ ap::vadd(&g(0), &network.weights(0), ap::vlen(0,wcount-1), decay);
+ for(k = 0; k <= wcount-1; k++)
+ {
+ h(k,k) = h(k,k)+decay;
+ }
+ rep.nhess = rep.nhess+1;
+ lambda = 0.001;
+ nu = 2;
+ while(true)
+ {
+
+ //
+ // 1. HMod = H+lambda*I
+ // 2. Try to solve (H+Lambda*I)*dx = -g.
+ // Increase lambda if left part is not positive definite.
+ //
+ for(i = 0; i <= wcount-1; i++)
+ {
+ ap::vmove(&hmod(i, 0), &h(i, 0), ap::vlen(0,wcount-1));
+ hmod(i,i) = hmod(i,i)+lambda;
+ }
+ spd = spdmatrixcholesky(hmod, wcount, true);
+ rep.ncholesky = rep.ncholesky+1;
+ if( !spd )
+ {
+ lambda = lambda*lambdaup*nu;
+ nu = nu*2;
+ continue;
+ }
+ if( !spdmatrixcholeskysolve(hmod, g, wcount, true, wdir) )
+ {
+ lambda = lambda*lambdaup*nu;
+ nu = nu*2;
+ continue;
+ }
+ ap::vmul(&wdir(0), ap::vlen(0,wcount-1), -1);
+
+ //
+ // Lambda found.
+ // 1. Save old w in WBase
+ // 2. Test some stopping criteria
+ // 3. If error(w+wdir)>error(w), increase lambda
+ //
+ ap::vadd(&network.weights(0), &wdir(0), ap::vlen(0,wcount-1));
+ xnorm2 = ap::vdotproduct(&network.weights(0), &network.weights(0), ap::vlen(0,wcount-1));
+ stepnorm = ap::vdotproduct(&wdir(0), &wdir(0), ap::vlen(0,wcount-1));
+ stepnorm = sqrt(stepnorm);
+ enew = mlperror(network, xy, npoints)+0.5*decay*xnorm2;
+ if( stepnorm<lmsteptol*(1+sqrt(xnorm2)) )
+ {
+ break;
+ }
+ if( enew>e )
+ {
+ lambda = lambda*lambdaup*nu;
+ nu = nu*2;
+ continue;
+ }
+
+ //
+ // Optimize using inv(cholesky(H)) as preconditioner
+ //
+ if( !rmatrixtrinverse(hmod, wcount, true, false) )
+ {
+
+ //
+ // if matrix can't be inverted then exit with errors
+ // TODO: make WCount steps in direction suggested by HMod
+ //
+ info = -9;
+ return;
+ }
+ ap::vmove(&wbase(0), &network.weights(0), ap::vlen(0,wcount-1));
+ for(i = 0; i <= wcount-1; i++)
+ {
+ wt(i) = 0;
+ }
+ minlbfgs(wcount, wcount, wt, 0.0, 0.0, 0.0, 5, 0, state);
+ while(minlbfgsiteration(state))
+ {
+
+ //
+ // gradient
+ //
+ for(i = 0; i <= wcount-1; i++)
+ {
+ v = ap::vdotproduct(&state.x(i), &hmod(i, i), ap::vlen(i,wcount-1));
+ network.weights(i) = wbase(i)+v;
+ }
+ mlpgradbatch(network, xy, npoints, state.f, g);
+ for(i = 0; i <= wcount-1; i++)
+ {
+ state.g(i) = 0;
+ }
+ for(i = 0; i <= wcount-1; i++)
+ {
+ v = g(i);
+ ap::vadd(&state.g(i), &hmod(i, i), ap::vlen(i,wcount-1), v);
+ }
+
+ //
+ // weight decay
+ // grad(x'*x) = A'*(x0+A*t)
+ //
+ v = ap::vdotproduct(&network.weights(0), &network.weights(0), ap::vlen(0,wcount-1));
+ state.f = state.f+0.5*decay*v;
+ for(i = 0; i <= wcount-1; i++)
+ {
+ v = decay*network.weights(i);
+ ap::vadd(&state.g(i), &hmod(i, i), ap::vlen(i,wcount-1), v);
+ }
+
+ //
+ // next iteration
+ //
+ rep.ngrad = rep.ngrad+1;
+ }
+ minlbfgsresults(state, wt, internalrep);
+
+ //
+ // Accept new position.
+ // Calculate Hessian
+ //
+ for(i = 0; i <= wcount-1; i++)
+ {
+ v = ap::vdotproduct(&wt(i), &hmod(i, i), ap::vlen(i,wcount-1));
+ network.weights(i) = wbase(i)+v;
+ }
+ mlphessianbatch(network, xy, npoints, e, g, h);
+ v = ap::vdotproduct(&network.weights(0), &network.weights(0), ap::vlen(0,wcount-1));
+ e = e+0.5*decay*v;
+ ap::vadd(&g(0), &network.weights(0), ap::vlen(0,wcount-1), decay);
+ for(k = 0; k <= wcount-1; k++)
+ {
+ h(k,k) = h(k,k)+decay;
+ }
+ rep.nhess = rep.nhess+1;
+
+ //
+ // Update lambda
+ //
+ lambda = lambda*lambdadown;
+ nu = 2;
+ }
+
+ //
+ // update WBest
+ //
+ v = ap::vdotproduct(&network.weights(0), &network.weights(0), ap::vlen(0,wcount-1));
+ e = 0.5*decay*v+mlperror(network, xy, npoints);
+ if( e<ebest )
+ {
+ ebest = e;
+ ap::vmove(&wbest(0), &network.weights(0), ap::vlen(0,wcount-1));
+ }
+ }
+
+ //
+ // copy WBest to output
+ //
+ ap::vmove(&network.weights(0), &wbest(0), ap::vlen(0,wcount-1));
+ }
+
+
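For C++ callers, a minimal usage sketch for MLPTrainLM follows. It assumes mlpcreate1 from mlpbase.h (the one-hidden-layer network constructor bundled in this gem's ext/alglib sources; check that header for the exact signature) and trains a small 1-5-1 regression network on y = x^2, with Decay and Restarts at the values the header comment above recommends:

#include "mlpbase.h"
#include "mlptrain.h"

int main()
{
    multilayerperceptron net;
    mlpreport rep;
    int info;

    // Assumed constructor from mlpbase.h: 1 input, 5 hidden, 1 output.
    mlpcreate1(1, 5, 1, net);

    // Training set: row i is [x, y] with y = x^2 on [-1,1].
    const int npoints = 21;
    ap::real_2d_array xy;
    xy.setbounds(0, npoints-1, 0, 1);
    for(int i = 0; i < npoints; i++)
    {
        double x = -1.0+2.0*i/(npoints-1);
        xy(i,0) = x;
        xy(i,1) = x*x;
    }

    // Decay=0.001, Restarts=2, as recommended above.
    mlptrainlm(net, xy, npoints, 0.001, 2, info, rep);
    return info==2 ? 0 : 1;
}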
+ /*************************************************************************
+ Neural network training using L-BFGS algorithm with regularization.
+ Subroutine trains neural network with restarts from random positions.
+ Algorithm is well suited for problems of any dimensionality (memory
+ requirements and step complexity are linear in the number of weights).
+
+ INPUT PARAMETERS:
+ Network - neural network with initialized geometry
+ XY - training set
+ NPoints - training set size
+ Decay - weight decay constant, >=0.001
+ Decay term 'Decay*||Weights||^2' is added to error
+ function.
+ If you don't know what Decay to choose, use 0.001.
+ Restarts - number of restarts from random position, >0.
+ If you don't know what Restarts to choose, use 2.
+ WStep - stopping criterion. Algorithm stops if step size is
+ less than WStep. Recommended value - 0.01. Zero step
+ size means stopping after MaxIts iterations.
+ MaxIts - stopping criterion. Algorithm stops after MaxIts
+ iterations (NOT gradient calculations). Zero MaxIts
+ means stopping when step is sufficiently small.
+
+ OUTPUT PARAMETERS:
+ Network - trained neural network.
+ Info - return code:
+ * -8, if both WStep=0 and MaxIts=0
+ * -2, if there is a point with class number
+ outside of [0..NOut-1].
+ * -1, if wrong parameters specified
+ (NPoints<=0, Restarts<1).
+ * 2, if task has been solved.
+ Rep - training report
+
+ -- ALGLIB --
+ Copyright 09.12.2007 by Bochkanov Sergey
+ *************************************************************************/
+ void mlptrainlbfgs(multilayerperceptron& network,
+ const ap::real_2d_array& xy,
+ int npoints,
+ double decay,
+ int restarts,
+ double wstep,
+ int maxits,
+ int& info,
+ mlpreport& rep)
+ {
+ int i;
+ int j;
+ int pass;
+ int nin;
+ int nout;
+ int wcount;
+ ap::real_1d_array w;
+ ap::real_1d_array wbest;
+ double e;
+ double v;
+ double ebest;
+ lbfgsreport internalrep;
+ lbfgsstate state;
+
+
+ //
+ // Test inputs, parse flags, read network geometry
+ //
+ if( wstep==0&&maxits==0 )
+ {
+ info = -8;
+ return;
+ }
+ if( npoints<=0||restarts<1||wstep<0||maxits<0 )
+ {
+ info = -1;
+ return;
+ }
+ mlpproperties(network, nin, nout, wcount);
+ if( mlpissoftmax(network) )
+ {
+ for(i = 0; i <= npoints-1; i++)
+ {
+ if( ap::round(xy(i,nin))<0||ap::round(xy(i,nin))>=nout )
+ {
+ info = -2;
+ return;
+ }
+ }
+ }
+ decay = ap::maxreal(decay, mindecay);
+ info = 2;
+
+ //
+ // Prepare
+ //
+ mlpinitpreprocessor(network, xy, npoints);
+ w.setbounds(0, wcount-1);
+ wbest.setbounds(0, wcount-1);
+ ebest = ap::maxrealnumber;
+
+ //
+ // Multiple starts
+ //
+ rep.ncholesky = 0;
+ rep.nhess = 0;
+ rep.ngrad = 0;
+ for(pass = 1; pass <= restarts; pass++)
+ {
+
+ //
+ // Process
+ //
+ mlprandomize(network);
+ ap::vmove(&w(0), &network.weights(0), ap::vlen(0,wcount-1));
+ minlbfgs(wcount, ap::minint(wcount, 50), w, 0.0, 0.0, wstep, maxits, 0, state);
+ while(minlbfgsiteration(state))
+ {
+ ap::vmove(&network.weights(0), &state.x(0), ap::vlen(0,wcount-1));
+ mlpgradnbatch(network, xy, npoints, state.f, state.g);
+ v = ap::vdotproduct(&network.weights(0), &network.weights(0), ap::vlen(0,wcount-1));
+ state.f = state.f+0.5*decay*v;
+ ap::vadd(&state.g(0), &network.weights(0), ap::vlen(0,wcount-1), decay);
+ rep.ngrad = rep.ngrad+1;
+ }
+ minlbfgsresults(state, w, internalrep);
+ ap::vmove(&network.weights(0), &w(0), ap::vlen(0,wcount-1));
+
+ //
+ // Compare with best
+ //
+ v = ap::vdotproduct(&network.weights(0), &network.weights(0), ap::vlen(0,wcount-1));
+ e = mlperrorn(network, xy, npoints)+0.5*decay*v;
+ if( e<ebest )
+ {
+ ap::vmove(&wbest(0), &network.weights(0), ap::vlen(0,wcount-1));
+ ebest = e;
+ }
+ }
+
+ //
+ // The best network
+ //
+ ap::vmove(&network.weights(0), &wbest(0), ap::vlen(0,wcount-1));
+ }
+
+
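The L-BFGS trainer adds two stopping parameters on top of the LM interface. A short sketch with the two criteria spelled out, assuming a net/xy/npoints prepared as in the MLPTrainLM sketch above:

#include "mlptrain.h"

// Sketch: L-BFGS training of an already-constructed network.
void train_lbfgs_sketch(multilayerperceptron& net,
    const ap::real_2d_array& xy, int npoints)
{
    int info;
    mlpreport rep;
    // WStep=0.01 stops on small steps; MaxIts=0 means no iteration cap.
    // (WStep=0 together with MaxIts=0 would be rejected with Info=-8.)
    mlptrainlbfgs(net, xy, npoints, 0.001, 2, 0.01, 0, info, rep);
}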
+ /*************************************************************************
+ Neural network training using early stopping (base algorithm - L-BFGS with
+ regularization).
+
+ INPUT PARAMETERS:
+ Network - neural network with initialized geometry
+ TrnXY - training set
+ TrnSize - training set size
+ ValXY - validation set
+ ValSize - validation set size
+ Decay - weight decay constant, >=0.001
+ Decay term 'Decay*||Weights||^2' is added to error
+ function.
+ If you don't know what Decay to choose, use 0.001.
+ Restarts - number of restarts from random position, >0.
+ If you don't know what Restarts to choose, use 2.
+
+ OUTPUT PARAMETERS:
+ Network - trained neural network.
+ Info - return code:
+ * -2, if there is a point with class number
+ outside of [0..NOut-1].
+ * -1, if wrong parameters specified
+ (NPoints<0, Restarts<1, ...).
+ * 2, task has been solved, stopping criterion met -
+ sufficiently small step size. Not expected (we
+ use EARLY stopping) but possible and not an
+ error.
+ * 6, task has been solved, stopping criterion met -
+ increasing of validation set error.
+ Rep - training report
+
+ NOTE:
+
+ Algorithm stops if the validation set error increases for long enough
+ or the step size becomes small enough (there are tasks where the
+ validation set error may decrease for eternity). In either case the
+ solution returned corresponds to the minimum of validation set error.
+
+ -- ALGLIB --
+ Copyright 10.03.2009 by Bochkanov Sergey
+ *************************************************************************/
+ void mlptraines(multilayerperceptron& network,
+ const ap::real_2d_array& trnxy,
+ int trnsize,
+ const ap::real_2d_array& valxy,
+ int valsize,
+ double decay,
+ int restarts,
+ int& info,
+ mlpreport& rep)
+ {
+ int i;
+ int j;
+ int pass;
+ int nin;
+ int nout;
+ int wcount;
+ ap::real_1d_array w;
+ ap::real_1d_array wbest;
+ double e;
+ double v;
+ double ebest;
+ ap::real_1d_array wfinal;
+ double efinal;
+ int itbest;
+ lbfgsreport internalrep;
+ lbfgsstate state;
+ double wstep;
+
+ wstep = 0.001;
+
+ //
+ // Test inputs, parse flags, read network geometry
+ //
+ if( trnsize<=0||valsize<=0||restarts<1||decay<0 )
+ {
+ info = -1;
+ return;
+ }
+ mlpproperties(network, nin, nout, wcount);
+ if( mlpissoftmax(network) )
+ {
+ for(i = 0; i <= trnsize-1; i++)
+ {
+ if( ap::round(trnxy(i,nin))<0||ap::round(trnxy(i,nin))>=nout )
+ {
+ info = -2;
+ return;
+ }
+ }
+ for(i = 0; i <= valsize-1; i++)
+ {
+ if( ap::round(valxy(i,nin))<0||ap::round(valxy(i,nin))>=nout )
+ {
+ info = -2;
+ return;
+ }
+ }
+ }
+ info = 2;
+
+ //
+ // Prepare
+ //
+ mlpinitpreprocessor(network, trnxy, trnsize);
+ w.setbounds(0, wcount-1);
+ wbest.setbounds(0, wcount-1);
+ wfinal.setbounds(0, wcount-1);
+ efinal = ap::maxrealnumber;
+ for(i = 0; i <= wcount-1; i++)
+ {
+ wfinal(i) = 0;
+ }
+
+ //
+ // Multiple starts
+ //
+ rep.ncholesky = 0;
+ rep.nhess = 0;
+ rep.ngrad = 0;
+ for(pass = 1; pass <= restarts; pass++)
+ {
+
+ //
+ // Process
+ //
+ mlprandomize(network);
+ ebest = mlperror(network, valxy, valsize);
+ ap::vmove(&wbest(0), &network.weights(0), ap::vlen(0,wcount-1));
+ itbest = 0;
+ ap::vmove(&w(0), &network.weights(0), ap::vlen(0,wcount-1));
+ minlbfgs(wcount, ap::minint(wcount, 50), w, 0.0, 0.0, wstep, 0, 0, state);
+ while(minlbfgsiteration(state))
+ {
+
+ //
+ // Calculate gradient
+ //
+ ap::vmove(&network.weights(0), &state.x(0), ap::vlen(0,wcount-1));
+ mlpgradnbatch(network, trnxy, trnsize, state.f, state.g);
+ v = ap::vdotproduct(&network.weights(0), &network.weights(0), ap::vlen(0,wcount-1));
+ state.f = state.f+0.5*decay*v;
+ ap::vadd(&state.g(0), &network.weights(0), ap::vlen(0,wcount-1), decay);
+ rep.ngrad = rep.ngrad+1;
+
+ //
+ // Validation set
+ //
+ if( state.xupdated )
+ {
+ ap::vmove(&network.weights(0), &w(0), ap::vlen(0,wcount-1));
+ e = mlperror(network, valxy, valsize);
+ if( e<ebest )
+ {
+ ebest = e;
+ ap::vmove(&wbest(0), &network.weights(0), ap::vlen(0,wcount-1));
+ itbest = internalrep.iterationscount;
+ }
+ if( internalrep.iterationscount>30&&internalrep.iterationscount>1.5*itbest )
+ {
+ info = 6;
+ break;
+ }
+ }
+ }
+ minlbfgsresults(state, w, internalrep);
+
+ //
+ // Compare with final answer
+ //
+ if( ebest<efinal )
+ {
+ ap::vmove(&wfinal(0), &wbest(0), ap::vlen(0,wcount-1));
+ efinal = ebest;
+ }
+ }
+
+ //
+ // The best network
+ //
+ ap::vmove(&network.weights(0), &wfinal(0), ap::vlen(0,wcount-1));
+ }
+
+
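MLPTrainES needs a separate validation set. A hedged sketch of an 80/20 split over an existing xy matrix (the two-column layout assumes the 1-input/1-output regression setup from the earlier sketches):

#include "mlptrain.h"

// Sketch: split xy into training/validation rows and run early stopping.
void train_es_sketch(multilayerperceptron& net,
    const ap::real_2d_array& xy, int npoints)
{
    int info;
    mlpreport rep;
    int trnsize = npoints*4/5;
    int valsize = npoints-trnsize;
    ap::real_2d_array trnxy, valxy;
    trnxy.setbounds(0, trnsize-1, 0, 1);
    valxy.setbounds(0, valsize-1, 0, 1);
    for(int i = 0; i < trnsize; i++)
        ap::vmove(&trnxy(i,0), &xy(i,0), ap::vlen(0,1));
    for(int i = 0; i < valsize; i++)
        ap::vmove(&valxy(i,0), &xy(trnsize+i,0), ap::vlen(0,1));

    mlptraines(net, trnxy, trnsize, valxy, valsize, 0.001, 2, info, rep);
    // Info==6: stopped on rising validation error (the expected case);
    // Info==2: stopped on small step size, also a valid solution.
}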
+ /*************************************************************************
+ Cross-validation estimate of generalization error.
+
+ Base algorithm - L-BFGS.
+
+ INPUT PARAMETERS:
+ Network - neural network with initialized geometry. Network is
+ not changed during cross-validation - it is used only
+ as a representative of its architecture.
+ XY - training set.
+ NPoints - training set size
+ Decay - weight decay, same as in MLPTrainLBFGS
+ Restarts - number of restarts, >0.
+ restarts are counted for each partition separately, so
+ total number of restarts will be Restarts*FoldsCount.
+ WStep - stopping criterion, same as in MLPTrainLBFGS
+ MaxIts - stopping criterion, same as in MLPTrainLBFGS
+ FoldsCount - number of folds in k-fold cross-validation,
+ 2<=FoldsCount<=NPoints.
+ recommended value: 10.
+
+ OUTPUT PARAMETERS:
+ Info - return code, same as in MLPTrainLBFGS
+ Rep - report, same as in MLPTrainLM/MLPTrainLBFGS
+ CVRep - generalization error estimates
+
+ -- ALGLIB --
+ Copyright 09.12.2007 by Bochkanov Sergey
+ *************************************************************************/
+ void mlpkfoldcvlbfgs(const multilayerperceptron& network,
+ const ap::real_2d_array& xy,
+ int npoints,
+ double decay,
+ int restarts,
+ double wstep,
+ int maxits,
+ int foldscount,
+ int& info,
+ mlpreport& rep,
+ mlpcvreport& cvrep)
+ {
+
+ mlpkfoldcvgeneral(network, xy, npoints, decay, restarts, foldscount, false, wstep, maxits, info, rep, cvrep);
+ }
+
+
+ /*************************************************************************
+ Cross-validation estimate of generalization error.
+
+ Base algorithm - Levenberg-Marquardt.
+
+ INPUT PARAMETERS:
+ Network - neural network with initialized geometry. Network is
+ not changed during cross-validation - it is used only
+ as a representative of its architecture.
+ XY - training set.
+ NPoints - training set size
+ Decay - weight decay, same as in MLPTrainLBFGS
+ Restarts - number of restarts, >0.
+ restarts are counted for each partition separately, so
+ total number of restarts will be Restarts*FoldsCount.
+ FoldsCount - number of folds in k-fold cross-validation,
+ 2<=FoldsCount<=NPoints.
+ recommended value: 10.
+
+ OUTPUT PARAMETERS:
+ Info - return code, same as in MLPTrainLBFGS
+ Rep - report, same as in MLPTrainLM/MLPTrainLBFGS
+ CVRep - generalization error estimates
+
+ -- ALGLIB --
+ Copyright 09.12.2007 by Bochkanov Sergey
+ *************************************************************************/
+ void mlpkfoldcvlm(const multilayerperceptron& network,
+ const ap::real_2d_array& xy,
+ int npoints,
+ double decay,
+ int restarts,
+ int foldscount,
+ int& info,
+ mlpreport& rep,
+ mlpcvreport& cvrep)
+ {
+
+ mlpkfoldcvgeneral(network, xy, npoints, decay, restarts, foldscount, true, 0.0, 0, info, rep, cvrep);
+ }
+
+
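Both public cross-validation wrappers delegate to MLPKFoldCVGeneral below, so only the base trainer and its stopping criteria differ between them. A sketch of a 10-fold run with each variant, again reusing net/xy/npoints from the earlier sketches:

#include <cstdio>
#include "mlptrain.h"

// Sketch: 10-fold CV with both wrappers. The network argument is only
// an architecture template and is left unchanged.
void kfoldcv_sketch(const multilayerperceptron& net,
    const ap::real_2d_array& xy, int npoints)
{
    int info;
    mlpreport rep;
    mlpcvreport cvrep;

    // LM-based variant.
    mlpkfoldcvlm(net, xy, npoints, 0.001, 2, 10, info, rep, cvrep);

    // L-BFGS-based variant with the extra WStep/MaxIts criteria.
    mlpkfoldcvlbfgs(net, xy, npoints, 0.001, 2, 0.01, 0, 10, info, rep, cvrep);

    if( info>0 )
        printf("CV RMS error: %f\n", cvrep.rmserror);
}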
+ /*************************************************************************
+ Internal cross-validation subroutine
+ *************************************************************************/
+ static void mlpkfoldcvgeneral(const multilayerperceptron& n,
+ const ap::real_2d_array& xy,
+ int npoints,
+ double decay,
+ int restarts,
+ int foldscount,
+ bool lmalgorithm,
+ double wstep,
+ int maxits,
+ int& info,
+ mlpreport& rep,
+ mlpcvreport& cvrep)
+ {
+ int i;
+ int fold;
+ int j;
+ int k;
+ multilayerperceptron network;
+ int nin;
+ int nout;
+ int rowlen;
+ int wcount;
+ int nclasses;
+ int tssize;
+ int cvssize;
+ ap::real_2d_array cvset;
+ ap::real_2d_array testset;
+ ap::integer_1d_array folds;
+ int relcnt;
+ mlpreport internalrep;
+ ap::real_1d_array x;
+ ap::real_1d_array y;
+
+
+ //
+ // Read network geometry, test parameters
+ //
+ mlpproperties(n, nin, nout, wcount);
+ if( mlpissoftmax(n) )
+ {
+ nclasses = nout;
+ rowlen = nin+1;
+ }
+ else
+ {
+ nclasses = -nout;
+ rowlen = nin+nout;
+ }
+ if( npoints<=0||foldscount<2||foldscount>npoints )
+ {
+ info = -1;
+ return;
+ }
+ mlpcopy(n, network);
+
+ //
+ // K-fold cross-validation.
+ // First, estimate generalization error
+ //
+ testset.setbounds(0, npoints-1, 0, rowlen-1);
+ cvset.setbounds(0, npoints-1, 0, rowlen-1);
+ x.setbounds(0, nin-1);
+ y.setbounds(0, nout-1);
+ mlpkfoldsplit(xy, npoints, nclasses, foldscount, false, folds);
+ cvrep.relclserror = 0;
+ cvrep.avgce = 0;
+ cvrep.rmserror = 0;
+ cvrep.avgerror = 0;
+ cvrep.avgrelerror = 0;
+ rep.ngrad = 0;
+ rep.nhess = 0;
+ rep.ncholesky = 0;
+ relcnt = 0;
+ for(fold = 0; fold <= foldscount-1; fold++)
+ {
+
+ //
+ // Separate set
+ //
+ tssize = 0;
+ cvssize = 0;
+ for(i = 0; i <= npoints-1; i++)
+ {
+ if( folds(i)==fold )
+ {
+ ap::vmove(&testset(tssize, 0), &xy(i, 0), ap::vlen(0,rowlen-1));
+ tssize = tssize+1;
+ }
+ else
+ {
+ ap::vmove(&cvset(cvssize, 0), &xy(i, 0), ap::vlen(0,rowlen-1));
+ cvssize = cvssize+1;
+ }
+ }
+
+ //
+ // Train on CV training set
+ //
+ if( lmalgorithm )
+ {
+ mlptrainlm(network, cvset, cvssize, decay, restarts, info, internalrep);
+ }
+ else
+ {
+ mlptrainlbfgs(network, cvset, cvssize, decay, restarts, wstep, maxits, info, internalrep);
+ }
+ if( info<0 )
+ {
+ cvrep.relclserror = 0;
+ cvrep.avgce = 0;
+ cvrep.rmserror = 0;
+ cvrep.avgerror = 0;
+ cvrep.avgrelerror = 0;
+ return;
+ }
+ rep.ngrad = rep.ngrad+internalrep.ngrad;
+ rep.nhess = rep.nhess+internalrep.nhess;
+ rep.ncholesky = rep.ncholesky+internalrep.ncholesky;
+
+ //
+ // Estimate error using CV test set
+ //
+ if( mlpissoftmax(network) )
+ {
+
+ //
+ // classification-only code
+ //
+ cvrep.relclserror = cvrep.relclserror+mlpclserror(network, testset, tssize);
+ cvrep.avgce = cvrep.avgce+mlperrorn(network, testset, tssize);
+ }
+ for(i = 0; i <= tssize-1; i++)
+ {
+ ap::vmove(&x(0), &testset(i, 0), ap::vlen(0,nin-1));
+ mlpprocess(network, x, y);
+ if( mlpissoftmax(network) )
+ {
+
+ //
+ // Classification-specific code
+ //
+ k = ap::round(testset(i,nin));
+ for(j = 0; j <= nout-1; j++)
+ {
+ if( j==k )
+ {
+ cvrep.rmserror = cvrep.rmserror+ap::sqr(y(j)-1);
+ cvrep.avgerror = cvrep.avgerror+fabs(y(j)-1);
+ cvrep.avgrelerror = cvrep.avgrelerror+fabs(y(j)-1);
+ relcnt = relcnt+1;
+ }
+ else
+ {
+ cvrep.rmserror = cvrep.rmserror+ap::sqr(y(j));
+ cvrep.avgerror = cvrep.avgerror+fabs(y(j));
+ }
+ }
+ }
+ else
+ {
+
+ //
+ // Regression-specific code
+ //
+ for(j = 0; j <= nout-1; j++)
+ {
+ cvrep.rmserror = cvrep.rmserror+ap::sqr(y(j)-testset(i,nin+j));
+ cvrep.avgerror = cvrep.avgerror+fabs(y(j)-testset(i,nin+j));
+ if( testset(i,nin+j)!=0 )
+ {
+ cvrep.avgrelerror = cvrep.avgrelerror+fabs((y(j)-testset(i,nin+j))/testset(i,nin+j));
+ relcnt = relcnt+1;
+ }
+ }
+ }
+ }
+ }
+ if( mlpissoftmax(network) )
+ {
+ cvrep.relclserror = cvrep.relclserror/npoints;
+ cvrep.avgce = cvrep.avgce/(log(double(2))*npoints);
+ }
+ cvrep.rmserror = sqrt(cvrep.rmserror/(npoints*nout));
+ cvrep.avgerror = cvrep.avgerror/(npoints*nout);
+ cvrep.avgrelerror = cvrep.avgrelerror/relcnt;
+ info = 1;
+ }
+
+
+ /*************************************************************************
+ Subroutine prepares K-fold split of the training set.
+
+ NOTES:
+ "NClasses>0" means that we have a classification task.
+ "NClasses<0" means a regression task with -NClasses real outputs.
+ *************************************************************************/
+ static void mlpkfoldsplit(const ap::real_2d_array& xy,
+ int npoints,
+ int nclasses,
+ int foldscount,
+ bool stratifiedsplits,
+ ap::integer_1d_array& folds)
+ {
+ int i;
+ int j;
+ int k;
+
+
+ //
+ // test parameters
+ //
+ ap::ap_error::make_assertion(npoints>0, "MLPKFoldSplit: wrong NPoints!");
+ ap::ap_error::make_assertion(nclasses>1||nclasses<0, "MLPKFoldSplit: wrong NClasses!");
+ ap::ap_error::make_assertion(foldscount>=2&&foldscount<=npoints, "MLPKFoldSplit: wrong FoldsCount!");
+ ap::ap_error::make_assertion(!stratifiedsplits, "MLPKFoldSplit: stratified splits are not supported!");
+
+ //
+ // Folds
+ //
+ folds.setbounds(0, npoints-1);
+ for(i = 0; i <= npoints-1; i++)
+ {
+ folds(i) = i*foldscount/npoints;
+ }
+ for(i = 0; i <= npoints-2; i++)
+ {
+ j = i+ap::randominteger(npoints-i);
+ if( j!=i )
+ {
+ k = folds(i);
+ folds(i) = folds(j);
+ folds(j) = k;
+ }
+ }
+ }
+
+
+
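MLPKFoldSplit's fold assignment is worth spelling out: it first labels row i with i*FoldsCount/NPoints, which yields folds whose sizes differ by at most one, and then applies a Fisher-Yates shuffle so fold membership is random. A standalone sketch of the same logic in plain C++ (the function itself is static, so it cannot be called from outside this file):

#include <cstdio>
#include <cstdlib>
#include <vector>

int main()
{
    const int npoints = 10, foldscount = 3;
    std::vector<int> folds(npoints);

    // Balanced labels: fold sizes differ by at most one.
    for(int i = 0; i < npoints; i++)
        folds[i] = i*foldscount/npoints;

    // Fisher-Yates shuffle, mirroring the loop in MLPKFoldSplit.
    for(int i = 0; i < npoints-1; i++)
    {
        int j = i+rand()%(npoints-i);
        int k = folds[i]; folds[i] = folds[j]; folds[j] = k;
    }

    for(int i = 0; i < npoints; i++)
        printf("row %d -> fold %d\n", i, folds[i]);
    return 0;
}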