ml4r 0.1.4 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. data/ext/ml4r/LinearRegression/LinearRegression.cpp +305 -0
  2. data/ext/ml4r/LinearRegression/OLSLinearRegression.cpp +75 -0
  3. data/ext/ml4r/MachineLearning/DecisionTree/DecisionTreeExperiment.cpp +50 -0
  4. data/ext/ml4r/MachineLearning/DecisionTree/DecisionTreeNode.cpp +195 -0
  5. data/ext/ml4r/MachineLearning/DecisionTree/NodeSplitter.cpp +551 -0
  6. data/ext/ml4r/MachineLearning/DecisionTree/NodeSplitterCategorical.cpp +22 -0
  7. data/ext/ml4r/MachineLearning/DecisionTree/NodeSplitterContinuous.cpp +21 -0
  8. data/ext/ml4r/MachineLearning/DecisionTree/SplitDefinition.cpp +142 -0
  9. data/ext/ml4r/MachineLearning/GBM/BernoulliCalculator.cpp +95 -0
  10. data/ext/ml4r/MachineLearning/GBM/GBMEstimator.cpp +601 -0
  11. data/ext/ml4r/MachineLearning/GBM/GBMOutput.cpp +86 -0
  12. data/ext/ml4r/MachineLearning/GBM/GBMRunner.cpp +117 -0
  13. data/ext/ml4r/MachineLearning/GBM/GaussianCalculator.cpp +94 -0
  14. data/ext/ml4r/MachineLearning/GBM/ZenithGBM.cpp +317 -0
  15. data/ext/ml4r/MachineLearning/MLData/MLData.cpp +232 -0
  16. data/ext/ml4r/MachineLearning/MLData/MLDataFields.cpp +1 -0
  17. data/ext/ml4r/MachineLearning/MLData/MLDataReader.cpp +139 -0
  18. data/ext/ml4r/MachineLearning/MLData/ZenithMLData.cpp +96 -0
  19. data/ext/ml4r/MachineLearning/MLData/ZenithMLDataReader.cpp +113 -0
  20. data/ext/ml4r/MachineLearning/MLExperiment.cpp +69 -0
  21. data/ext/ml4r/MachineLearning/MLRunner.cpp +183 -0
  22. data/ext/ml4r/MachineLearning/MLUtils.cpp +15 -0
  23. data/ext/ml4r/MachineLearning/RandomForest/RandomForestEstimator.cpp +172 -0
  24. data/ext/ml4r/MachineLearning/RandomForest/RandomForestOutput.cpp +66 -0
  25. data/ext/ml4r/MachineLearning/RandomForest/RandomForestRunner.cpp +84 -0
  26. data/ext/ml4r/MachineLearning/RandomForest/ZenithRandomForest.cpp +184 -0
  27. data/ext/ml4r/ml4r.cpp +34 -0
  28. data/ext/ml4r/ml4r_wrap.cpp +15727 -0
  29. data/ext/ml4r/utils/MathUtils.cpp +204 -0
  30. data/ext/ml4r/utils/StochasticUtils.cpp +73 -0
  31. data/ext/ml4r/utils/Utils.cpp +14 -0
  32. data/ext/ml4r/utils/VlcMessage.cpp +3 -0
  33. metadata +33 -1
@@ -0,0 +1,305 @@
+ #include <boost/numeric/ublas/vector.hpp>
+ #include <boost/numeric/ublas/vector_proxy.hpp>
+ #include <boost/numeric/ublas/matrix.hpp>
+ #include <boost/numeric/ublas/triangular.hpp>
+ #include <boost/numeric/ublas/lu.hpp>
+ #include <boost/numeric/ublas/io.hpp>
+ #include <boost/foreach.hpp>
+ #include <iostream>
+ using std::cout;
+ using std::endl;
+
+ #include "LinearRegression/LinearRegression.h"
+ #include "utils/MatrixInversion.h"
+ #include "utils/Utils.h"
+ namespace ublas = boost::numeric::ublas;
+
+ using std::vector;
+ using ublas::prod;
+ using ublas::matrix;
+
+
+ void LinearRegression::setWeights(vector<double> weights)
+ {
+     m_ws = weights;
+ }
+
+ void LinearRegression::setFixedConstant(double val)
+ {
+     m_constant = val;
+     m_constantIsFixed = true;
+ }
+
+ pair<vector<double>,double> LinearRegression::getParameterEstimates()
+ {
+     return make_pair(m_bs, m_constant);
+ }
+
+ void LinearRegression::checkDimensions()
+ {
+     if (!m_ys.size())
+         throw std::runtime_error("[LinearRegression] Number of observations equals zero");
+
+     if (m_xs.size() != m_ys.size())
+         throw std::runtime_error("[LinearRegression] Number of observations in x doesn't match number of observations in y");
+
+     if (m_ws.size() && m_ws.size() != m_ys.size())
+         throw std::runtime_error("[LinearRegression] Number of specified weights doesn't match number of observations");
+
+     unsigned long dimensionOfX = m_xs.front().size();
+     BOOST_FOREACH(vector<double>& x, m_xs)
+         if (x.size() != dimensionOfX)
+             throw std::runtime_error("[LinearRegression] Dimensions of x variables are inconsistent between observations");
+ }
+
+ void LinearRegression::calculateStatistics()
+ {
+     if (!m_paramsAreValid)
+         throw std::runtime_error("[LinearRegression] Parameters have not been estimated");
+
+     calculateModelStatistics();
+     calculateParameterStatistics();
+ }
+
+ void LinearRegression::calculateParameterStatistics2()
+ {
+     // form the matrix X'X
+     ublas::matrix<double> X(m_xs.size(), m_xs.front().size()+1);
+     ublas::matrix<double>::iterator2 matrixIterator = X.begin2();
+     BOOST_FOREACH(vector<double>& row, m_xs)
+     {
+         matrixIterator = std::copy(row.begin(), row.end(), matrixIterator);
+         *(matrixIterator++) = 1.0;
+     }
+     ublas::matrix<double> X_transpose_X = ublas::prod(ublas::trans(X), X);
+
+     // Invert the matrix
+     ublas::matrix<double> X_transpose_X_inverse(X_transpose_X);
+     InvertMatrix(X_transpose_X, X_transpose_X_inverse);
+
+     // Also construct a t-stat for the constant
+     if (!m_constantIsFixed) m_bs.push_back(m_constant);
+
+     m_tStatistics.resize(m_bs.size());
+     for (unsigned int i=0; i<m_bs.size(); ++i)
+     {
+         m_tStatistics.at(i) = m_bs.at(i) / (m_s * sqrt(X_transpose_X_inverse(i,i)));
+     }
+
+     if (!m_constantIsFixed) m_bs.pop_back();
+ }
+
+ void LinearRegression::calculateModelStatistics()
+ {
+     checkDimensions();
+     checkParametersAreEstimated();
+     estimateYs();
+
+     double meanY = Utils::vectorSum(m_ys) / m_n;
+     double sumSquaresTotal = 0.0;
+     double sumSquaresRegression = 0.0;
+     double sumSquaresError = 0.0;
+     double meanWeight = Utils::vectorSum(m_ws) / m_n;
+     for (int i=0; i<m_n; ++i)
+     {
+         sumSquaresTotal += m_ws.at(i) / meanWeight * pow(m_ys.at(i) - meanY, 2.0);
+         sumSquaresRegression += m_ws.at(i) / meanWeight * pow(m_fittedYs.at(i) - meanY, 2.0);
+         sumSquaresError += m_ws.at(i) / meanWeight * pow(m_ys.at(i) - m_fittedYs.at(i), 2.0);
+     }
+
+     double meanSquaredRegression = sumSquaresRegression / (m_k);
+
+     m_rSquared = 1.0 - (sumSquaresError / sumSquaresTotal);
+     m_adjustedRSquared = 1.0 - (sumSquaresError / (m_n - m_p)) / (sumSquaresTotal / (m_n - 1));
+     m_fStatistic = (m_n-m_p) * sumSquaresRegression / (sumSquaresError * m_k);
+     m_sSquared = 1.0 / (m_n-m_p) * sumSquaresError;
+     m_s = sqrt(m_sSquared);
+
+     m_h_diagonal.resize(m_n, 0.0);
+     // auto XIterator = m_X.begin2(); // row-wise
+     // auto AIterator = m_A.begin1(); // column-wise
+     for (int i = 0; i < m_n; ++i)
+     {
+         double sumProduct = 0.0;
+         for (int j = 0; j < m_p; ++j)
+             sumProduct += m_X(i, j) * m_A(j, i);
+         m_h_diagonal.at(i) = sumProduct;
+     }
+
+     m_pressStatistic = 0.0;
+     m_presarStatistic = 0.0;
+
+     m_predictedYs.resize(m_n);
+     for (int i = 0; i < m_n; ++i)
+     {
+         double ei = m_fittedYs.at(i) - m_ys.at(i);
+         double hii = m_h_diagonal.at(i);
+         double ei_prediction = ei / (1.0 - hii); // leave-one-out (PRESS) prediction residual via the leverage h_ii
+         m_predictedYs.at(i) = m_ys.at(i) + ei_prediction;
+         m_presarStatistic += m_ws.at(i) / meanWeight * fabs(ei_prediction);
+         m_pressStatistic += m_ws.at(i) / meanWeight * pow(ei_prediction, 2.0);
+     }
+     m_rSquaredPrediction = 1.0 - m_pressStatistic / sumSquaresTotal;
+ }
+
+ void LinearRegression::estimateYs()
+ {
+     m_fittedYs.clear();
+     m_fittedYs.resize(m_ys.size(), m_constant);
+     for (unsigned int i=0; i<m_ys.size(); ++i)
+     {
+         for (unsigned int j=0; j<m_bs.size(); ++j)
+             m_fittedYs.at(i) += m_bs.at(j) * m_xs.at(i).at(j);
+     }
+ }
+
+ void LinearRegression::checkParametersAreEstimated()
+ {
+     if (!m_paramsAreValid)
+         throw std::runtime_error("[LinearRegression] Parameters have not been estimated");
+ }
+
+ double LinearRegression::getRSquared()
+ {
+     return m_rSquared;
+ }
+
+ double LinearRegression::getFstatistic()
+ {
+     return m_fStatistic;
+ }
+
+ vector<double>& LinearRegression::getFittedYs()
+ {
+     return m_fittedYs;
+ }
+
+ vector<double>& LinearRegression::getTstatistics()
+ {
+     return m_tStatistics;
+ }
+
+ void LinearRegression::populateMembers()
+ {
+     m_k = m_xs.front().size();
+     m_p = m_k + (m_constantIsFixed ? 0 : 1);
+     m_n = m_xs.size();
+
+     // populate m_X
+     m_X.resize(m_n, m_p);
+     ublas::matrix<double>::iterator2 matrixIterator = m_X.begin2();
+     BOOST_FOREACH(vector<double>& row, m_xs)
+     {
+         matrixIterator = std::copy(row.begin(), row.end(), matrixIterator);
+         if (!m_constantIsFixed) *(matrixIterator++) = 1.0;
+     }
+
+     // populate m_Y
+     m_Y.resize(m_n, 1);
+     ublas::matrix<double>::iterator1 matrixIterator2 = m_Y.begin1();
+     BOOST_FOREACH(double& y, m_ys)
+     {
+         (*matrixIterator2) = y;
+         ++matrixIterator2;
+     }
+
+     // populate m_ws with 1's if it's not already defined
+     if (!m_ws.size())
+     {
+         m_ws.resize(m_n, 1.0);
+     }
+
+     // form the matrix X' [P x N]
+     m_Xtranspose = ublas::trans(m_X);
+
+     // form the matrix X'WX: [P x N] . [N x N] . [N x P] => [P x P]
+     m_Xtranspose_W_X.resize(m_p, m_p);
+     m_Xtranspose_W = multiplyMatrixByWeights(m_Xtranspose);
+     m_Xtranspose_W_X = ublas::prod(m_Xtranspose_W, m_X);
+
+     // Invert the matrix
+     m_Xtranspose_W_X_inverse.resize(m_p, m_p);
+     InvertMatrix(m_Xtranspose_W_X, m_Xtranspose_W_X_inverse);
+ }
+
+ void LinearRegression::calculateParameterStatistics()
+ {
+     m_tStatistics.resize(m_p);
+     m_standardErrors.resize(m_p);
+
+     ublas::matrix<double> AAt = prod(m_A, ublas::trans(m_A));
+     for (int i=0; i<m_p; ++i)
+     {
+         // made more complicated by weights!!!
+         m_standardErrors.at(i) = m_s * sqrt(AAt(i,i));
+         m_tStatistics.at(i) = m_B(i,0) / m_standardErrors.at(i);
+     }
+ }
+
+ double LinearRegression::getPressStatistic()
+ {
+     return m_pressStatistic;
+ }
+
+ double LinearRegression::getPresarStatistic()
+ {
+     return m_presarStatistic;
+ }
+
+ double LinearRegression::getRSquaredPrediction()
+ {
+     return m_rSquaredPrediction;
+ }
+
+ vector<double>& LinearRegression::getPredictedYs()
+ {
+     return m_predictedYs;
+ }
+
+ double LinearRegression::getAdjustedRSquared()
+ {
+     return m_adjustedRSquared;
+ }
+
+ matrix<double> LinearRegression::multiplyMatrixByWeights(matrix<double>& mat)
+ {
+     if (mat.size2() != m_ws.size())
+         throw std::runtime_error("[LinearRegression::multiplyMatrixByWeights] invalid matrix dimensions!");
+
+     matrix<double> new_matrix = mat; // copy
+     for (unsigned int j = 0; j < new_matrix.size2(); ++j) // each column
+     {
+         double weight = m_ws.at(j);
+         for (unsigned int i = 0; i < new_matrix.size1(); ++i) // each row
+             new_matrix(i,j) *= weight;
+     }
+     return new_matrix;
+ }
+
+ matrix<double> LinearRegression::multiplyWeightsByMatrix(matrix<double>& mat)
+ {
+     if (mat.size1() != m_ws.size())
+         throw std::runtime_error("[LinearRegression::multiplyWeightsByMatrix] invalid matrix dimensions!");
+
+     matrix<double> new_matrix = mat; // copy
+     for (unsigned int i = 0; i < new_matrix.size1(); ++i) // each row
+     {
+         double weight = m_ws.at(i);
+         for (unsigned int j = 0; j < new_matrix.size2(); ++j) // each column
+             new_matrix(i,j) *= weight;
+     }
+     return new_matrix;
+ }
+
+ vector<double>& LinearRegression::getStandardErrors()
+ {
+     return m_standardErrors;
+ }
+
+ double LinearRegression::getSSquared()
+ {
+     return m_sSquared;
+ }
+
+
+
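
For reference, the leverage and leave-one-out quantities that calculateModelStatistics() computes are the standard weighted-regression identities, written here with A = (X'WX)^{-1}X'W (the matrix m_A formed during estimation), residual e_i = fitted_i - y_i, and mean weight \bar{w}:

    h_{ii} = (XA)_{ii}, \qquad e_i^{(-i)} = \frac{e_i}{1 - h_{ii}}

    \mathrm{PRESS} = \sum_{i=1}^{n} \frac{w_i}{\bar{w}} \bigl(e_i^{(-i)}\bigr)^2, \qquad R^2_{\mathrm{pred}} = 1 - \frac{\mathrm{PRESS}}{\mathrm{SS}_{\mathrm{tot}}}

calculateParameterStatistics() then reports \mathrm{SE}_i = s\sqrt{(AA^{\top})_{ii}} and t_i = b_i / \mathrm{SE}_i.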
@@ -0,0 +1,75 @@
+ #include "LinearRegression/OLSLinearRegression.h"
+ #include "utils/MathUtils.h"
+ #include "utils/Utils.h"
+
+ #include <iostream>
+ #include <boost/numeric/ublas/io.hpp>
+ using std::cout;
+ using std::endl;
+
+ namespace ublas = boost::numeric::ublas;
+ using Utils::operator+=;
+ using ublas::matrix;
+ using ublas::prod;
+
+ OLSLinearRegression::OLSLinearRegression(std::vector<double> xs, std::vector<double> ys,
+                                          std::vector<double> weights)
+     : LinearRegression(xs, ys, weights)
+ {
+     calculate();
+ }
+
+ OLSLinearRegression::OLSLinearRegression(std::vector<std::vector<double> > xs, std::vector<double> ys,
+                                          std::vector<double> weights)
+     : LinearRegression(xs, ys, weights)
+ {
+     calculate();
+ }
+
+ OLSLinearRegression::OLSLinearRegression(std::vector<std::vector<double> > xs, std::vector<double> ys,
+                                          double fixedConstant, std::vector<double> weights)
+     : LinearRegression(xs, ys, fixedConstant, weights)
+ {
+     calculate();
+ }
+
+ OLSLinearRegression::~OLSLinearRegression()
+ {}
+
+ void OLSLinearRegression::calculate()
+ {
+     checkDimensions();
+     // matrix based implementation
+
+     // b = inverse(X'WX)X'Wy
+     // where X is the data matrix (rows are observations, columns are our x variables; if a constant is to be
+     //   estimated, a column of 1's is appended as the last column of X, and the last estimated parameter is the constant)
+     // X' is the transpose of X
+     // W is a matrix with diag(w1, w2, w3) etc, where wi is the weight of observation i
+     // y is the column matrix containing the observed y's.
+     populateMembers();
+     EstimateBs();
+     if (m_paramsAreValid) calculateStatistics();
+ }
+
+ void OLSLinearRegression::EstimateBs()
+ {
+     matrix<double> Y = m_Y;
+     if (m_constantIsFixed)
+     {
+         for (int i = 0; i < m_n; ++i) Y(i, 0) -= m_constant;
+     }
+
+     m_A = prod(m_Xtranspose_W_X_inverse, m_Xtranspose_W);
+     m_B = prod(m_A, Y);
+
+     // set m_bs and constant
+     m_bs.resize(m_k);
+     for (int i = 0; i < m_k; ++i)
+         m_bs.at(i) = m_B(i, 0);
+
+     if (!m_constantIsFixed)
+         m_constant = m_B(m_p-1, 0);
+
+     m_paramsAreValid = true;
+ }
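
In matrix form, the estimator described in the comment in calculate(), with W = diag(w_1, ..., w_n), is

    \hat{b} = (X^{\top} W X)^{-1} X^{\top} W y

which EstimateBs() evaluates as m_B = prod(m_A, Y), with m_A = (X'WX)^{-1}X'W precomputed in populateMembers(). A minimal usage sketch, relying only on the constructor and accessor defined in this diff (the toy data is hypothetical):

    #include "LinearRegression/OLSLinearRegression.h"
    #include <iostream>
    #include <vector>

    int main()
    {
        std::vector<std::vector<double> > xs; // one inner vector per observation
        std::vector<double> ys;
        std::vector<double> weights;          // left empty: populateMembers() defaults every weight to 1.0

        // hypothetical data lying exactly on y = 2x + 1
        for (int i = 0; i < 10; ++i)
        {
            xs.push_back(std::vector<double>(1, double(i)));
            ys.push_back(2.0 * i + 1.0);
        }

        OLSLinearRegression ols(xs, ys, weights);
        std::pair<std::vector<double>, double> estimates = ols.getParameterEstimates();
        std::cout << "slope: " << estimates.first.at(0)
                  << ", constant: " << estimates.second << std::endl;
        return 0;
    }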
@@ -0,0 +1,50 @@
+ #include "MachineLearning/DecisionTree/DecisionTreeExperiment.h"
+
+ DecisionTreeExperiment::DecisionTreeExperiment()
+     : MLExperiment()
+ {
+
+ }
+
+ DecisionTreeExperiment::DecisionTreeExperiment(shared_ptr<MLExperiment> mlExperiment)
+     : MLExperiment(mlExperiment)
+ {
+
+ }
+
+
+
+ DecisionTreeExperiment::~DecisionTreeExperiment()
+ {
+
+ }
+
+ void DecisionTreeExperiment::setF(double f)
+ {
+     m_F = f;
+ }
+
+ void DecisionTreeExperiment::setZ(double z)
+ {
+     m_Z = z;
+ }
+
+ double DecisionTreeExperiment::getF()
+ {
+     return m_F;
+ }
+
+ double DecisionTreeExperiment::getZ()
+ {
+     return m_Z;
+ }
+
+ void DecisionTreeExperiment::incrementF(double increment)
+ {
+     m_F += increment;
+ }
+
+ double DecisionTreeExperiment::getY()
+ {
+     return m_yValue;
+ }
@@ -0,0 +1,195 @@
+ #include "MachineLearning/DecisionTree/DecisionTreeNode.h"
+ #include "MachineLearning/DecisionTree/SplitDefinition.h"
+ #include "MachineLearning/DecisionTree/DecisionTreeExperiment.h"
+ #include "utils/Utils.h"
+
+ #include <stdexcept>
+ using std::runtime_error;
+
+ bool DecisionTreeNode::m_missingValueDefined = false;
+ double DecisionTreeNode::m_missingValue = -1.0;
+
+ DecisionTreeNode::DecisionTreeNode( vector<shared_ptr<DecisionTreeExperiment> > experiments,
+                                     double sumZ,
+                                     double sumW,
+                                     Partition partition,
+                                     shared_ptr<SplitDefinition> parentSplitDefinition)
+     : m_experiments(experiments), m_nodeHasChildren(false), m_sumZ(sumZ), m_sumW(sumW),
+       m_whichPartitionAmI(partition), m_parentSplitDefinition(parentSplitDefinition)
+ {
+
+ }
+
+ DecisionTreeNode::~DecisionTreeNode()
+ {
+
+ }
+
+ shared_ptr<DecisionTreeNode> DecisionTreeNode::getTerminalNodeForExperiment(shared_ptr<DecisionTreeExperiment> experiment)
+ {
+     if (!m_nodeHasChildren)
+         throw std::runtime_error("Node is a terminal node, so you shouldn't ask it for a terminal node!");
+
+     if (m_splitDefinition.get() == 0)
+         throw std::runtime_error("Node has children, but split definition is empty");
+
+     shared_ptr<DecisionTreeNode> childForExperiment = chooseChild(experiment);
+
+     if (childForExperiment.get() == 0)
+         return childForExperiment;
+     else if (childForExperiment->getSumW() == 0)
+     {
+         // this likely means that the value is missing, but there weren't any missing values in the
+         // bagged training set. Therefore, there is no weight in the missing child.
+         // return an empty pointer, and this DecisionTreeNode will become the one chosen.
+         return shared_ptr<DecisionTreeNode>();
+     }
+     else if (childForExperiment->isTerminalNode())
+         return childForExperiment;
+     else
+     {
+         shared_ptr<DecisionTreeNode> terminalNode = childForExperiment->getTerminalNodeForExperiment(experiment);
+         if (terminalNode.get() == 0)
+         {
+             // we have encountered a NEW category...therefore we couldn't split on childForExperiment.
+             // therefore, return the child itself.
+             return childForExperiment;
+         }
+         else
+             return terminalNode;
+     }
+
+ }
+
+ shared_ptr<DecisionTreeNode> DecisionTreeNode::chooseChild(shared_ptr<DecisionTreeExperiment> experiment)
+ {
+     if (!m_nodeHasChildren)
+         throw std::runtime_error("[DecisionTreeNode::chooseChild] - this Decision Tree has no children!");
+
+     double featureValue = experiment->getFeatureValue(m_splitDefinition->getFeatureIndex());
+
+     if (m_missingValueDefined && m_missingValue == featureValue)
+         return m_missingChild;
+
+     if (m_splitDefinition->isCategorical()) // categorical variable
+     {
+         if (Utils::hasElement(m_splitDefinition->getLhsCategories(), featureValue))
+             return m_lhsChild;
+         else if (Utils::hasElement(m_splitDefinition->getRhsCategories(), featureValue))
+             return m_rhsChild;
+         else
+         {
+             // it's not missing, but not in left or right. Therefore, we have a NEW category.
+             // We should return an empty pointer, and let the parent handle it.
+             return shared_ptr<DecisionTreeNode>();
+         }
+     }
+     else // continuous variable
+     {
+         double splitValue = m_splitDefinition->getSplitValue();
+         if (m_missingValueDefined && m_missingValue == splitValue)
+         {
+             // complicated logic. Our split value equals the missing value. Therefore, we split off missing versus
+             // everything else (which gets put in the rhsChild). As our feature value is not the missing value, we choose
+             // the rhsChild.
+             return m_rhsChild;
+         }
+         else if (featureValue < splitValue)
+             return m_lhsChild;
+         else
+             return m_rhsChild;
+     }
+ }
+
+ void DecisionTreeNode::defineSplit( shared_ptr<SplitDefinition> splitDefinition,
+                                     shared_ptr<DecisionTreeNode> lhsChild,
+                                     shared_ptr<DecisionTreeNode> rhsChild,
+                                     shared_ptr<DecisionTreeNode> missingChild)
+ {
+     setChildren(lhsChild, rhsChild, missingChild);
+     m_splitDefinition = splitDefinition;
+ }
+
+ void DecisionTreeNode::setChildren( shared_ptr<DecisionTreeNode> lhsChild,
+                                     shared_ptr<DecisionTreeNode> rhsChild,
+                                     shared_ptr<DecisionTreeNode> missingChild)
+ {
+     m_nodeHasChildren = true;
+     m_lhsChild = lhsChild;
+     m_rhsChild = rhsChild;
+     m_missingChild = missingChild;
+ }
+
+ vector<shared_ptr<DecisionTreeExperiment> > DecisionTreeNode::getExperiments()
+ {
+     return m_experiments;
+ }
+
+ bool DecisionTreeNode::isTerminalNode()
+ {
+     return !m_nodeHasChildren;
+ }
+
+ void DecisionTreeNode::clearExperimentsWithinTree()
+ {
+     m_experiments.clear();
+     if (m_nodeHasChildren)
+     {
+         m_lhsChild->clearExperimentsWithinTree();
+         m_rhsChild->clearExperimentsWithinTree();
+         m_missingChild->clearExperimentsWithinTree();
+     }
+ }
+
+ double DecisionTreeNode::getSumZ()
+ {
+     return m_sumZ;
+ }
+
+ double DecisionTreeNode::getSumW()
+ {
+     return m_sumW;
+ }
+
+ void DecisionTreeNode::setMissingValue( double missingValue )
+ {
+     m_missingValue = missingValue;
+     m_missingValueDefined = true;
+ }
+
+ shared_ptr<SplitDefinition> DecisionTreeNode::getSplitDefinition()
+ {
+     return m_splitDefinition;
+ }
+
+ shared_ptr<SplitDefinition> DecisionTreeNode::getParentSplitDefinition()
+ {
+     return m_parentSplitDefinition;
+ }
+
+ Partition DecisionTreeNode::getPartition()
+ {
+     return m_whichPartitionAmI;
+ }
+
+ void DecisionTreeNode::setSumZ( double sumZ )
+ {
+     m_sumZ = sumZ;
+ }
+
+ void DecisionTreeNode::setSumW( double sumW )
+ {
+     m_sumW = sumW;
+ }
+
+ void DecisionTreeNode::updateSums()
+ {
+     m_sumW = 0.0;
+     m_sumZ = 0.0;
+     for (unsigned int i=0; i<m_experiments.size(); ++i)
+     {
+         double w = m_experiments.at(i)->getWeight();
+         m_sumW += w;
+         m_sumZ += w * m_experiments.at(i)->getZ();
+     }
+ }
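
The empty-pointer convention documented in chooseChild() and getTerminalNodeForExperiment() gives callers a uniform way to route an experiment: descend from the root, and if the tree returns an empty pointer (unseen category, or no weight in the missing child), fall back to the node already reached. A minimal caller sketch under that contract; treating a node's value as getSumZ()/getSumW() is an assumption about how the estimators use these sums, not something this file defines:

    // Hypothetical helper: score one experiment against a fitted tree rooted at 'root'.
    double scoreExperiment(shared_ptr<DecisionTreeNode> root,
                           shared_ptr<DecisionTreeExperiment> experiment)
    {
        shared_ptr<DecisionTreeNode> node = root;
        if (!root->isTerminalNode())
        {
            shared_ptr<DecisionTreeNode> terminal = root->getTerminalNodeForExperiment(experiment);
            if (terminal.get() != 0)
                node = terminal; // empty pointer => stay at the root (e.g. unseen category)
        }
        // assumed node value: weighted mean of Z over the node's experiments (m_sumZ / m_sumW)
        return node->getSumZ() / node->getSumW();
    }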