ml4r 0.1.4 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data/ext/ml4r/LinearRegression/LinearRegression.cpp +305 -0
- data/ext/ml4r/LinearRegression/OLSLinearRegression.cpp +75 -0
- data/ext/ml4r/MachineLearning/DecisionTree/DecisionTreeExperiment.cpp +50 -0
- data/ext/ml4r/MachineLearning/DecisionTree/DecisionTreeNode.cpp +195 -0
- data/ext/ml4r/MachineLearning/DecisionTree/NodeSplitter.cpp +551 -0
- data/ext/ml4r/MachineLearning/DecisionTree/NodeSplitterCategorical.cpp +22 -0
- data/ext/ml4r/MachineLearning/DecisionTree/NodeSplitterContinuous.cpp +21 -0
- data/ext/ml4r/MachineLearning/DecisionTree/SplitDefinition.cpp +142 -0
- data/ext/ml4r/MachineLearning/GBM/BernoulliCalculator.cpp +95 -0
- data/ext/ml4r/MachineLearning/GBM/GBMEstimator.cpp +601 -0
- data/ext/ml4r/MachineLearning/GBM/GBMOutput.cpp +86 -0
- data/ext/ml4r/MachineLearning/GBM/GBMRunner.cpp +117 -0
- data/ext/ml4r/MachineLearning/GBM/GaussianCalculator.cpp +94 -0
- data/ext/ml4r/MachineLearning/GBM/ZenithGBM.cpp +317 -0
- data/ext/ml4r/MachineLearning/MLData/MLData.cpp +232 -0
- data/ext/ml4r/MachineLearning/MLData/MLDataFields.cpp +1 -0
- data/ext/ml4r/MachineLearning/MLData/MLDataReader.cpp +139 -0
- data/ext/ml4r/MachineLearning/MLData/ZenithMLData.cpp +96 -0
- data/ext/ml4r/MachineLearning/MLData/ZenithMLDataReader.cpp +113 -0
- data/ext/ml4r/MachineLearning/MLExperiment.cpp +69 -0
- data/ext/ml4r/MachineLearning/MLRunner.cpp +183 -0
- data/ext/ml4r/MachineLearning/MLUtils.cpp +15 -0
- data/ext/ml4r/MachineLearning/RandomForest/RandomForestEstimator.cpp +172 -0
- data/ext/ml4r/MachineLearning/RandomForest/RandomForestOutput.cpp +66 -0
- data/ext/ml4r/MachineLearning/RandomForest/RandomForestRunner.cpp +84 -0
- data/ext/ml4r/MachineLearning/RandomForest/ZenithRandomForest.cpp +184 -0
- data/ext/ml4r/ml4r.cpp +34 -0
- data/ext/ml4r/ml4r_wrap.cpp +15727 -0
- data/ext/ml4r/utils/MathUtils.cpp +204 -0
- data/ext/ml4r/utils/StochasticUtils.cpp +73 -0
- data/ext/ml4r/utils/Utils.cpp +14 -0
- data/ext/ml4r/utils/VlcMessage.cpp +3 -0
- metadata +33 -1
@@ -0,0 +1,551 @@
|
|
1
|
+
#include "MachineLearning/DecisionTree/NodeSplitter.h"
|
2
|
+
#include "MachineLearning/DecisionTree/DecisionTreeExperiment.h"
|
3
|
+
#include "MachineLearning/DecisionTree/SplitDefinition.h"
|
4
|
+
#include "MachineLearning/DecisionTree/DecisionTreeNode.h"
|
5
|
+
#include "MachineLearning/DecisionTree/CategoryInfo.h"
|
6
|
+
#include "MachineLearning/MLData/MLData.h"
|
7
|
+
#include "MachineLearning/GBM/GBMEstimator.h"
|
8
|
+
#include "utils/Utils.h"
|
9
|
+
#include "utils/StochasticUtils.h"
|
10
|
+
|
11
|
+
#include <boost/foreach.hpp>
|
12
|
+
#include <boost/lexical_cast.hpp>
|
13
|
+
#include <cmath>
|
14
|
+
using boost::lexical_cast;
|
15
|
+
|
16
|
+
// Builds a splitter over `data`. A split candidate is only admissible when
// both sides keep at least `minObservations` rows; `scale` controls the
// stochastic split selection (infinity => always pick the best split).
NodeSplitter::NodeSplitter(MLData* data, int minObservations, double scale)
    : m_data(data), m_minObservations(minObservations), m_scale(scale)
{
    // Cache the missing-value configuration so the hot splitting loops do not
    // repeatedly query the data object.
    m_missingValueDefined = m_data->missingValueDefined();
    if (m_missingValueDefined)
    {
        m_missingValue = m_data->getMissingValue();
    }
}
|
23
|
+
|
24
|
+
// Nothing to release: the MLData pointer is borrowed, everything else is
// managed by shared_ptr.
NodeSplitter::~NodeSplitter()
{
}
|
25
|
+
|
26
|
+
// Improvement of a candidate split, measured as the weighted between-group
// spread of the partition means (larger is better).
//
// @param lhsSumW/lhsSumZ         total weight and weighted response on the LHS
// @param rhsSumW/rhsSumZ         total weight and weighted response on the RHS
// @param missingSumW/missingSumZ totals for the missing-value bucket
// @return 0.0 when either main side is empty, otherwise the improvement score.
double NodeSplitter::calculateImprovement(double lhsSumW, double lhsSumZ, double rhsSumW, double rhsSumZ, double missingSumW, double missingSumZ)
{
    // Robustness fix: an empty LHS or RHS previously produced 0/0 => NaN.
    // NaN silently fails every `improvement > bestImprovement` comparison and
    // can leak into the candidate-improvement vectors used for stochastic
    // selection. An empty side cannot improve anything, so report 0 explicitly.
    // (createCategoricalSplitDefinition calls this with lhsSumW == 0 on its
    // first sweep iteration.)
    if (lhsSumW == 0 || rhsSumW == 0)
        return 0.0;

    double improvement = 0.0;

    if (missingSumW == 0)
    {
        // Two-way split: classic weighted squared difference of group means.
        double meanZDifference = lhsSumZ / lhsSumW - rhsSumZ / rhsSumW;
        improvement = lhsSumW * rhsSumW * meanZDifference * meanZDifference / (lhsSumW + rhsSumW);
    }
    else
    {
        // Three-way split (non-empty missing bucket): sum the pairwise
        // weighted squared mean differences, normalised by the total weight.
        double meanLRDifference = lhsSumZ / lhsSumW - rhsSumZ / rhsSumW;
        double meanLMDifference = lhsSumZ / lhsSumW - missingSumZ / missingSumW;
        double meanRMDifference = rhsSumZ / rhsSumW - missingSumZ / missingSumW;

        improvement += lhsSumW * rhsSumW * meanLRDifference * meanLRDifference;
        improvement += lhsSumW * missingSumW * meanLMDifference * meanLMDifference;
        improvement += rhsSumW * missingSumW * meanRMDifference * meanRMDifference;
        improvement /= (lhsSumW + rhsSumW + missingSumW);
    }

    return improvement;
}
|
49
|
+
|
50
|
+
// double NodeSplitter::calculateImprovement(double lhsSumW, double lhsSumZ, double rhsSumW, double rhsSumZ, double missingSumW, double missingSumZ)
|
51
|
+
// {
|
52
|
+
// double improvement = 0.0;
|
53
|
+
//
|
54
|
+
// if (lhsSumW)
|
55
|
+
// improvement += pow(lhsSumZ, 2) / lhsSumW;
|
56
|
+
//
|
57
|
+
// if (rhsSumW)
|
58
|
+
// improvement += pow(rhsSumZ, 2) / rhsSumW;
|
59
|
+
//
|
60
|
+
// if (missingSumW)
|
61
|
+
// improvement += pow(missingSumZ, 2) / missingSumW;
|
62
|
+
//
|
63
|
+
// return improvement;
|
64
|
+
// }
|
65
|
+
|
66
|
+
// Convenience wrapper: materialise the left-hand-side child of a split.
shared_ptr<DecisionTreeNode> NodeSplitter::createLhsChild( shared_ptr<SplitDefinition> splitDefinition )
{
    shared_ptr<DecisionTreeNode> lhsChild = createChild(splitDefinition, LHS);
    return lhsChild;
}
|
70
|
+
|
71
|
+
// Convenience wrapper: materialise the right-hand-side child of a split.
shared_ptr<DecisionTreeNode> NodeSplitter::createRhsChild( shared_ptr<SplitDefinition> splitDefinition )
{
    shared_ptr<DecisionTreeNode> rhsChild = createChild(splitDefinition, RHS);
    return rhsChild;
}
|
75
|
+
|
76
|
+
// Convenience wrapper: materialise the missing-value child of a split.
shared_ptr<DecisionTreeNode> NodeSplitter::createMissingChild( shared_ptr<SplitDefinition> splitDefinition )
{
    shared_ptr<DecisionTreeNode> missingChild = createChild(splitDefinition, MISSING);
    return missingChild;
}
|
80
|
+
|
81
|
+
// Builds the child node for one side (LHS / RHS / MISSING) of a split:
// selects the experiments that fall on that side and pairs them with the
// response/weight sums already recorded on the split definition.
shared_ptr<DecisionTreeNode> NodeSplitter::createChild( shared_ptr<SplitDefinition> splitDefinition, Partition partition )
{
    vector<shared_ptr<DecisionTreeExperiment> > allExperiments = splitDefinition->getNodeToSplit()->getExperiments();
    vector<shared_ptr<DecisionTreeExperiment> > childExperiments = partitionExperiments(allExperiments, splitDefinition, partition);

    double sumZ = 0.0;
    double sumW = 0.0;
    switch (partition)
    {
    case LHS:
        sumZ = splitDefinition->getLhsSumZ();
        sumW = splitDefinition->getLhsSumW();
        break;
    case RHS:
        sumZ = splitDefinition->getRhsSumZ();
        sumW = splitDefinition->getRhsSumW();
        break;
    default: // MISSING
        sumZ = splitDefinition->getMissingSumZ();
        sumW = splitDefinition->getMissingSumW();
        break;
    }

    return shared_ptr<DecisionTreeNode>(new DecisionTreeNode(childExperiments, sumZ, sumW, partition, splitDefinition));
}
|
108
|
+
|
109
|
+
// Returns the subset of `experiments` that falls on the requested side
// (LHS / RHS / MISSING) of `splitDefinition`.
//
// Fix: removed the unused local `bool rhs = !partition;` — it was never read,
// and applying `!` to a Partition enum value was misleading at best.
vector<shared_ptr<DecisionTreeExperiment> > NodeSplitter::partitionExperiments(vector<shared_ptr<DecisionTreeExperiment> >& experiments,
                                                                               shared_ptr<SplitDefinition> splitDefinition, Partition partition)
{
    vector<shared_ptr<DecisionTreeExperiment> > partitionExperiments;

    // Reserve using the element counts recorded on the split definition.
    if (partition == LHS)
        partitionExperiments.reserve(splitDefinition->getLhsExperimentCount());
    else if (partition == RHS)
        partitionExperiments.reserve(splitDefinition->getRhsExperimentCount());
    else if (partition == MISSING)
        partitionExperiments.reserve(splitDefinition->getMissingExperimentCount());

    int featureIndex = splitDefinition->getFeatureIndex();

    if (splitDefinition->isCategorical())
    {
        // Categorical split: membership of the feature value in the lhs/rhs
        // category sets decides the side; the missing marker goes to MISSING.
        set<double>& lhsCategories = splitDefinition->getLhsCategories();
        set<double>& rhsCategories = splitDefinition->getRhsCategories();

        BOOST_FOREACH(shared_ptr<DecisionTreeExperiment>& experiment, experiments)
        {
            double featureValue = experiment->getFeatureValue(featureIndex);
            set<double>::const_iterator lhsIt = lhsCategories.find(featureValue);
            set<double>::const_iterator rhsIt = rhsCategories.find(featureValue);

            if ((partition == MISSING && m_missingValueDefined && m_missingValue == featureValue) ||
                (partition == LHS && lhsIt != lhsCategories.end()) ||
                (partition == RHS && rhsIt != rhsCategories.end()))
            {
                partitionExperiments.push_back(experiment);
            }
        }
    }
    else
    {
        // Continuous split: compare against the threshold; when a missing
        // marker is defined, marker-valued rows are routed to MISSING only.
        // NOTE(review): a value exactly equal to splitValue lands on neither
        // side — splitValue is a midpoint of two observed values so this
        // should not occur for training rows, but confirm for scoring data.
        double splitValue = splitDefinition->getSplitValue();
        BOOST_FOREACH(shared_ptr<DecisionTreeExperiment>& experiment, experiments)
        {
            double featureValue = experiment->getFeatureValue(featureIndex);

            if (m_missingValueDefined && featureValue == m_missingValue)
            {
                if (partition == MISSING)
                    partitionExperiments.push_back(experiment);
            }
            else if ((partition == LHS && featureValue < splitValue) ||
                     (partition == RHS && featureValue > splitValue))
                partitionExperiments.push_back(experiment);
        }
    }
    return partitionExperiments;
}
|
165
|
+
|
166
|
+
// Evaluates a split candidate for each feature in `featuresToConsider`,
// picks one (best, or stochastically when m_scale is finite), attaches the
// three children (LHS, RHS, MISSING) to `nodeToSplit` and returns them.
// Returns an empty vector when the node has no weight or no feature yields
// a positive improvement.
vector<shared_ptr<DecisionTreeNode> > NodeSplitter::splitNode( shared_ptr<DecisionTreeNode> nodeToSplit, vector<int> featuresToConsider )
{
    vector<shared_ptr<DecisionTreeNode> > children;

    if (nodeToSplit->getSumW() == 0)
        return children;

    shared_ptr<SplitDefinition> bestSplit;
    double bestImprovement = 0.0;

    // One entry per considered feature (invalid splits are kept as null with
    // improvement 0 so indices line up for the stochastic choice below).
    vector<double> vecImprovements;
    vector<shared_ptr<SplitDefinition> > vecSplits;
    vecImprovements.reserve(featuresToConsider.size());
    vecSplits.reserve(featuresToConsider.size());

    set<int>& categoricalFeatures = m_data->getCategoricalFeatureIndices();

    BOOST_FOREACH(int featureIndex, featuresToConsider)
    {
        shared_ptr<SplitDefinition> split;

        if (Utils::hasElement(categoricalFeatures,featureIndex))
            split = createCategoricalSplitDefinition(nodeToSplit, featureIndex);
        else
            split = createContinuousSplitDefinition(nodeToSplit, featureIndex);

        vecSplits.push_back(split);
        vecImprovements.push_back(split.get() ? split->getImprovement() : 0);

        if (!split.get()) // no valid split exists for this feature
            continue;

        if (split->getImprovement() > bestImprovement)
        {
            bestImprovement = split->getImprovement();
            bestSplit = split;
        }
    }
    if (bestImprovement == 0.0)
        return children;

    // Stochastic feature selection: sample a candidate with probability
    // proportional to exp(scale * improvement / bestImprovement), mirroring
    // createContinuousSplitDefinition.
    // Bug fix: the exp() call was missing here (raw ratios were used despite
    // the `exp_u` name). Invalid (null) splits are pinned to weight 0 —
    // a plain exp(0) == 1 would have made them selectable and dereferencing
    // a null bestSplit below would crash.
    if (m_scale != std::numeric_limits<double>::infinity() && vecImprovements.size() > 1)
    {
        vector<float> exp_u;
        exp_u.reserve(vecImprovements.size());
        for (size_t i = 0; i < vecImprovements.size(); ++i)
        {
            if (!vecSplits.at(i).get())
                exp_u.push_back(0.0f);
            else
                exp_u.push_back(static_cast<float>(exp(m_scale * vecImprovements.at(i) / bestImprovement)));
        }

        vector<float> pdf = StochasticUtils::convertHistogramToPdf(exp_u);
        int bestIndex = StochasticUtils::chooseCategoryFromPdf(pdf);
        bestImprovement = vecImprovements.at(bestIndex);
        bestSplit = vecSplits.at(bestIndex);
    }

    shared_ptr<DecisionTreeNode> lhsChild = createLhsChild(bestSplit);
    shared_ptr<DecisionTreeNode> rhsChild = createRhsChild(bestSplit);
    shared_ptr<DecisionTreeNode> missingChild = createMissingChild(bestSplit);

    nodeToSplit->defineSplit(bestSplit,lhsChild,rhsChild,missingChild);

    // The caller removes the node just split from its terminal set and
    // continues with these children.
    children.push_back(lhsChild);
    children.push_back(rhsChild);
    children.push_back(missingChild);

    return children;
}
|
241
|
+
|
242
|
+
struct FeatureSorter
|
243
|
+
{
|
244
|
+
FeatureSorter()
|
245
|
+
: featureIndexToSort(-1)
|
246
|
+
{}
|
247
|
+
|
248
|
+
int featureIndexToSort;
|
249
|
+
|
250
|
+
bool operator() (shared_ptr<DecisionTreeExperiment> a, shared_ptr<DecisionTreeExperiment> b)
|
251
|
+
{
|
252
|
+
if (featureIndexToSort == -1)
|
253
|
+
throw std::runtime_error("SortOnFeature object doesn't know which feature to sort on!");
|
254
|
+
|
255
|
+
return a->getFeatureValue(featureIndexToSort) < b->getFeatureValue(featureIndexToSort);
|
256
|
+
}
|
257
|
+
} featureSorter;
|
258
|
+
|
259
|
+
// Sweeps every admissible threshold on a continuous feature and returns a
// SplitDefinition (best candidate, or stochastically chosen when m_scale is
// finite), or a null pointer when no admissible threshold exists.
shared_ptr<SplitDefinition> NodeSplitter::createContinuousSplitDefinition( shared_ptr<DecisionTreeNode> node, int featureIndex )
{
    // Work on a copy of the node's experiments so the sort below does not
    // disturb the node's own ordering.
    vector<shared_ptr<DecisionTreeExperiment> > sortedExperiments = node->getExperiments();
    // vector<shared_ptr<DecisionTreeExperiment> >& sortedExperiments = node->getSortedExperimentsForFeature(featureIndex);

    // NOTE(review): featureSorter is a file-level global; writing its index
    // here is not thread-safe if trees are built concurrently.
    featureSorter.featureIndexToSort = featureIndex;
    sort(sortedExperiments.begin(), sortedExperiments.end(), featureSorter);

    double rhsSumZ = 0, rhsSumW = 0, lhsSumZ = 0, lhsSumW = 0;
    double missingSumZ = 0, missingSumW = 0;

    // Per-candidate bookkeeping: one parallel entry per admissible position.
    vector<double> vecLhsSumZ;
    vector<double> vecLhsSumW;
    vector<int> vecLhsCount;
    vector<double> vecRhsSumZ;
    vector<double> vecRhsSumW;
    vector<int> vecRhsCount;
    vector<double> vecMissingSumZ;
    vector<double> vecMissingSumW;
    vector<int> vecMissingCount;
    vector<double> vecImprovement;
    vector<int> vecPosition;

    // Statistics of the finally chosen candidate (filled from the vectors
    // once bestIndex is known).
    double bestLhsSumZ;
    double bestLhsSumW;
    int bestLhsCount;
    double bestRhsSumZ;
    double bestRhsSumW;
    int bestRhsCount;
    double bestMissingSumZ;
    double bestMissingSumW;
    int bestMissingCount;

    double bestImprovement = 0.0;
    int bestPosition = -1;
    int bestIndex = -1;

    // Start with every experiment on the RHS and sweep them leftwards.
    int lhsCount = 0, missingCount = 0;
    int rhsCount = (int) sortedExperiments.size();

    rhsSumZ = node->getSumZ();
    rhsSumW = node->getSumW();
    int position = -1;
    double previousFeatureValue = 0;

    BOOST_FOREACH(shared_ptr<DecisionTreeExperiment> experiment, sortedExperiments)
    {
        double featureValue = experiment->getFeatureValue(featureIndex);

        if (featureValue != previousFeatureValue)
        {
            // (debug tracing of the accumulators, intentionally disabled)
            // vlcMessage.Write("lhsSumW => " + ToString(lhsSumW)); ... etc.
        }

        // A threshold is admissible only where the sorted feature value
        // changes (so equal values are never split apart) and both sides
        // meet the minimum observation count. The candidate is recorded
        // BEFORE the current experiment is moved across — ordering matters.
        if (featureValue != previousFeatureValue &&
            lhsCount >= m_minObservations &&
            rhsCount >= m_minObservations
            )
        {
            double improvement = calculateImprovement(lhsSumW, lhsSumZ, rhsSumW, rhsSumZ, missingSumW, missingSumZ);
            vecPosition.push_back(position);
            vecImprovement.push_back(improvement);
            vecLhsSumZ.push_back(lhsSumZ);
            vecLhsSumW.push_back(lhsSumW);
            vecLhsCount.push_back(lhsCount);
            vecRhsSumZ.push_back(rhsSumZ);
            vecRhsSumW.push_back(rhsSumW);
            vecRhsCount.push_back(rhsCount);
            vecMissingSumZ.push_back(missingSumZ);
            vecMissingSumW.push_back(missingSumW);
            vecMissingCount.push_back(missingCount);

            if (improvement > bestImprovement)
            {
                bestImprovement = improvement;
                bestPosition = position;
                // Index of this candidate within the parallel vectors above.
                bestIndex = (int) vecPosition.size() - 1;
            }
        }
        // Move the current experiment off the RHS, onto the LHS or — when its
        // value equals the missing marker — into the missing bucket.
        double weight = experiment->getWeight();
        double z = experiment->getZ();
        rhsSumZ -= weight * z;
        rhsSumW -= weight;
        --rhsCount;

        if (m_missingValueDefined && featureValue == m_missingValue)
        {
            missingSumZ += weight * z;
            missingSumW += weight;
            ++missingCount;
        }
        else
        {
            lhsSumZ += weight * z;
            lhsSumW += weight;
            ++lhsCount;
        }

        previousFeatureValue = featureValue;
        ++position;
    }

    // No admissible candidate at all.
    if (bestPosition == -1)
        return shared_ptr<SplitDefinition>();

    // Stochastic candidate selection: with a finite scale, sample a candidate
    // with probability proportional to exp(scale * improvement / best).
    if (m_scale != std::numeric_limits<double>::infinity() && vecImprovement.size() > 1)
    {
        vector<float> exp_u;
        exp_u.reserve(vecImprovement.size());
        BOOST_FOREACH(double& improvement, vecImprovement)
        {
            exp_u.push_back(exp(m_scale * improvement / bestImprovement));
        }
        vector<float> pdf = StochasticUtils::convertHistogramToPdf(exp_u);
        bestIndex = StochasticUtils::chooseCategoryFromPdf(pdf, "improvements");
    }

    // Rehydrate the chosen candidate's statistics.
    bestLhsSumZ = vecLhsSumZ.at(bestIndex);
    bestLhsSumW = vecLhsSumW.at(bestIndex);
    bestLhsCount = vecLhsCount.at(bestIndex);
    bestRhsSumZ = vecRhsSumZ.at(bestIndex);
    bestRhsSumW = vecRhsSumW.at(bestIndex);
    bestRhsCount = vecRhsCount.at(bestIndex);
    bestMissingSumZ = vecMissingSumZ.at(bestIndex);
    bestMissingSumW = vecMissingSumW.at(bestIndex);
    bestMissingCount = vecMissingCount.at(bestIndex);
    bestImprovement = vecImprovement.at(bestIndex);
    bestPosition = vecPosition.at(bestIndex);

    // A split after the last experiment would leave an empty RHS; that should
    // be impossible given the admissibility rules above.
    if (bestPosition >= (int) (sortedExperiments.size()-1))
        throw std::runtime_error(string("Unexpected bestPosition: ") + lexical_cast<string>(bestPosition));

    double lhsFeatureValue = sortedExperiments.at(bestPosition)->getFeatureValue(featureIndex);
    double rhsFeatureValue = sortedExperiments.at(bestPosition + 1)->getFeatureValue(featureIndex);

    // Threshold is the midpoint of the neighbouring values; when the left
    // value is the missing marker the marker itself becomes the split value.
    double splitValue;
    if (m_missingValueDefined && (lhsFeatureValue == m_missingValue))
        splitValue = m_missingValue;
    else
        splitValue = 0.5 * (lhsFeatureValue + rhsFeatureValue);


    shared_ptr<SplitDefinition> splitDefinition = shared_ptr<SplitDefinition>
        (new SplitDefinition(node, featureIndex, splitValue, bestLhsSumZ, bestLhsSumW, bestLhsCount,
                             bestRhsSumZ, bestRhsSumW, bestRhsCount, bestMissingSumZ, bestMissingSumW,
                             bestMissingCount, bestImprovement));

    // create SplitDefinition
    return splitDefinition;
}
|
434
|
+
|
435
|
+
// Builds the best binary partition of a categorical feature's categories for
// `node`, or a null pointer when the feature has only one observed category.
shared_ptr<SplitDefinition> NodeSplitter::createCategoricalSplitDefinition( shared_ptr<DecisionTreeNode> node, int featureIndex )
{
    vector<shared_ptr<DecisionTreeExperiment> > experiments = node->getExperiments();

    // Aggregate per-category statistics; marker-valued rows go to the
    // missing bucket instead.
    map<double, CategoryInfo> experimentsPerCategory;

    double missingSumZ = 0, missingSumW = 0;
    int missingCount = 0;
    BOOST_FOREACH(shared_ptr<DecisionTreeExperiment>& experiment, experiments)
    {
        double featureValue = experiment->getFeatureValue(featureIndex);

        if (m_missingValueDefined && m_missingValue == featureValue)
        {
            double w = experiment->getWeight();
            double z = experiment->getZ();
            missingSumZ += w * z;
            missingSumW += w;
            missingCount++;
        }
        else
        {
            CategoryInfo& info = experimentsPerCategory[featureValue];
            info.category = featureValue;
            info.addExperiment(experiment);
        }
    }

    if (experimentsPerCategory.size() == 1)
        return shared_ptr<SplitDefinition>(); // can't split one thing!

    // put them into a vector to make sorting easier!
    // (CategoryInfo's operator< defines the sweep order — presumably by mean
    // response; confirm in CategoryInfo.h.)
    vector<CategoryInfo> categoryInfoVector;
    typedef pair<double, CategoryInfo> ElementType;
    BOOST_FOREACH(ElementType e, experimentsPerCategory)
        categoryInfoVector.push_back(e.second);

    sort(categoryInfoVector.begin(), categoryInfoVector.end());

    double rhsSumZ = 0, rhsSumW = 0, lhsSumZ = 0, lhsSumW = 0;
    double bestImprovement = 0.0;
    int bestPosition = -1;
    double bestLhsSumZ = -1;
    double bestLhsSumW = -1;
    int bestLhsCount = -1;
    double bestRhsSumZ = -1;
    double bestRhsSumW = -1;
    int bestRhsCount = -1;
    int lhsCount = 0, rhsCount = 0;

    // Start with every category on the RHS and sweep categories leftwards.
    rhsSumZ = node->getSumZ();
    rhsSumW = node->getSumW();
    rhsCount = (int) experiments.size();

    int position = -1;

    BOOST_FOREACH(CategoryInfo& e, categoryInfoVector)
    {
        // Evaluated BEFORE this category moves across, so position == -1
        // corresponds to "nothing on the LHS" (never admissible once
        // m_minObservations > 0).
        double improvement = calculateImprovement(lhsSumW, lhsSumZ, rhsSumW, rhsSumZ, missingSumW, missingSumZ);
        if ( improvement > bestImprovement
            && lhsCount >= m_minObservations
            && rhsCount >= m_minObservations)
        {
            bestImprovement = improvement;
            bestPosition = position;
            bestLhsSumZ = lhsSumZ;
            bestLhsSumW = lhsSumW;
            bestLhsCount = lhsCount;
            bestRhsSumZ = rhsSumZ;
            bestRhsSumW = rhsSumW;
            bestRhsCount = rhsCount;
        }

        // Move this category from the RHS to the LHS.
        ++position;
        rhsSumW -= e.sumW;
        rhsSumZ -= e.sumZ;
        rhsCount -= e.countN;

        lhsSumW += e.sumW;
        lhsSumZ += e.sumZ;
        lhsCount += e.countN;
    }

    // NOTE(review): when no admissible position was found but the missing
    // bucket is non-empty, execution continues with bestPosition == -1 and
    // the -1 sentinel sums: every category lands on the RHS and the split is
    // effectively missing-vs-rest. Confirm this is intended.
    if (bestPosition == -1 && missingSumW == 0)
        return shared_ptr<SplitDefinition>();

    // Categories at or before bestPosition form the LHS set, the rest the RHS.
    set<double> lhsCategories;
    set<double> rhsCategories;

    int index = -1;
    BOOST_FOREACH(CategoryInfo& info, categoryInfoVector)
    {
        ++index;
        if (index <= bestPosition)
            lhsCategories.insert(info.category);
        else
            rhsCategories.insert(info.category);
    }

    // we have what we need to create a split definition now
    shared_ptr<SplitDefinition> splitDefinition = shared_ptr<SplitDefinition>
        (new SplitDefinition(node, featureIndex, lhsCategories, rhsCategories, bestLhsSumZ, bestLhsSumW,
                             bestLhsCount, bestRhsSumZ, bestRhsSumW, bestRhsCount, missingSumZ, missingSumW, missingCount, bestImprovement));

    return splitDefinition;
}
|
541
|
+
|
542
|
+
// Dispatches to the categorical or continuous split builder depending on how
// the feature is declared in the data set.
shared_ptr<SplitDefinition> NodeSplitter::createSplitDefinition( shared_ptr<DecisionTreeNode> node, int featureIndex )
{
    bool categorical = Utils::hasElement(m_data->getCategoricalFeatureIndices(),featureIndex);
    return categorical ? createCategoricalSplitDefinition(node, featureIndex)
                       : createContinuousSplitDefinition(node, featureIndex);
}
|
549
|
+
|
550
|
+
|
551
|
+
|
@@ -0,0 +1,22 @@
|
|
1
|
+
#include "MachineLearning/DecisionTree/NodeSplitterCategorical.h"
|
2
|
+
#include "MachineLearning/DecisionTree/DecisionTreeNode.h"
|
3
|
+
#include "MachineLearning/DecisionTree/DecisionTreeExperiment.h"
|
4
|
+
#include "MachineLearning/DecisionTree/SplitDefinition.h"
|
5
|
+
#include "MachineLearning/DecisionTree/CategoryInfo.h"
|
6
|
+
|
7
|
+
// Forwards construction straight to the generic NodeSplitter; this subclass
// adds no state of its own.
NodeSplitterCategorical::NodeSplitterCategorical(MLData* data, int minObservations, double scale)
    : NodeSplitter(data, minObservations, scale)
{}
|
12
|
+
|
13
|
+
// Nothing beyond the base class to clean up.
NodeSplitterCategorical::~NodeSplitterCategorical()
{}
|
17
|
+
|
18
|
+
// Builds the split definition for a categorical feature on `node`.
//
// Bug fix: the body was empty, so a function with a non-void return type
// flowed off its end — undefined behaviour in C++ (typically garbage in the
// returned shared_ptr). Delegate to the categorical implementation on the
// NodeSplitter base class, which is what this subclass exists to select.
shared_ptr<SplitDefinition> NodeSplitterCategorical::createSplitDefinition(shared_ptr<DecisionTreeNode> node,
                                                                           int featureIndex)
{
    return createCategoricalSplitDefinition(node, featureIndex);
}
|
@@ -0,0 +1,21 @@
|
|
1
|
+
#include "MachineLearning/DecisionTree/NodeSplitterContinuous.h"
|
2
|
+
#include "MachineLearning/DecisionTree/DecisionTreeExperiment.h"
|
3
|
+
#include "MachineLearning/DecisionTree/SplitDefinition.h"
|
4
|
+
#include "MachineLearning/DecisionTree/DecisionTreeNode.h"
|
5
|
+
|
6
|
+
|
7
|
+
|
8
|
+
// Forwards construction straight to the generic NodeSplitter; this subclass
// adds no state of its own.
NodeSplitterContinuous::NodeSplitterContinuous(MLData* data, int minObservations, double scale)
    : NodeSplitter(data, minObservations, scale)
{}
|
12
|
+
|
13
|
+
// Nothing beyond the base class to clean up.
NodeSplitterContinuous::~NodeSplitterContinuous()
{}
|
17
|
+
|
18
|
+
// Builds the split definition for a continuous feature on `node`.
//
// Bug fix: the body was empty, so a function with a non-void return type
// flowed off its end — undefined behaviour in C++ (typically garbage in the
// returned shared_ptr). Delegate to the continuous implementation on the
// NodeSplitter base class, which is what this subclass exists to select.
shared_ptr<SplitDefinition> NodeSplitterContinuous::createSplitDefinition(shared_ptr<DecisionTreeNode> node, int featureIndex)
{
    return createContinuousSplitDefinition(node, featureIndex);
}
|