ml4r 0.1.4 → 0.1.5

Sign up to get free protection for your applications and to get access to all the features.
Files changed (33) hide show
  1. data/ext/ml4r/LinearRegression/LinearRegression.cpp +305 -0
  2. data/ext/ml4r/LinearRegression/OLSLinearRegression.cpp +75 -0
  3. data/ext/ml4r/MachineLearning/DecisionTree/DecisionTreeExperiment.cpp +50 -0
  4. data/ext/ml4r/MachineLearning/DecisionTree/DecisionTreeNode.cpp +195 -0
  5. data/ext/ml4r/MachineLearning/DecisionTree/NodeSplitter.cpp +551 -0
  6. data/ext/ml4r/MachineLearning/DecisionTree/NodeSplitterCategorical.cpp +22 -0
  7. data/ext/ml4r/MachineLearning/DecisionTree/NodeSplitterContinuous.cpp +21 -0
  8. data/ext/ml4r/MachineLearning/DecisionTree/SplitDefinition.cpp +142 -0
  9. data/ext/ml4r/MachineLearning/GBM/BernoulliCalculator.cpp +95 -0
  10. data/ext/ml4r/MachineLearning/GBM/GBMEstimator.cpp +601 -0
  11. data/ext/ml4r/MachineLearning/GBM/GBMOutput.cpp +86 -0
  12. data/ext/ml4r/MachineLearning/GBM/GBMRunner.cpp +117 -0
  13. data/ext/ml4r/MachineLearning/GBM/GaussianCalculator.cpp +94 -0
  14. data/ext/ml4r/MachineLearning/GBM/ZenithGBM.cpp +317 -0
  15. data/ext/ml4r/MachineLearning/MLData/MLData.cpp +232 -0
  16. data/ext/ml4r/MachineLearning/MLData/MLDataFields.cpp +1 -0
  17. data/ext/ml4r/MachineLearning/MLData/MLDataReader.cpp +139 -0
  18. data/ext/ml4r/MachineLearning/MLData/ZenithMLData.cpp +96 -0
  19. data/ext/ml4r/MachineLearning/MLData/ZenithMLDataReader.cpp +113 -0
  20. data/ext/ml4r/MachineLearning/MLExperiment.cpp +69 -0
  21. data/ext/ml4r/MachineLearning/MLRunner.cpp +183 -0
  22. data/ext/ml4r/MachineLearning/MLUtils.cpp +15 -0
  23. data/ext/ml4r/MachineLearning/RandomForest/RandomForestEstimator.cpp +172 -0
  24. data/ext/ml4r/MachineLearning/RandomForest/RandomForestOutput.cpp +66 -0
  25. data/ext/ml4r/MachineLearning/RandomForest/RandomForestRunner.cpp +84 -0
  26. data/ext/ml4r/MachineLearning/RandomForest/ZenithRandomForest.cpp +184 -0
  27. data/ext/ml4r/ml4r.cpp +34 -0
  28. data/ext/ml4r/ml4r_wrap.cpp +15727 -0
  29. data/ext/ml4r/utils/MathUtils.cpp +204 -0
  30. data/ext/ml4r/utils/StochasticUtils.cpp +73 -0
  31. data/ext/ml4r/utils/Utils.cpp +14 -0
  32. data/ext/ml4r/utils/VlcMessage.cpp +3 -0
  33. metadata +33 -1
@@ -0,0 +1,551 @@
1
+ #include "MachineLearning/DecisionTree/NodeSplitter.h"
2
+ #include "MachineLearning/DecisionTree/DecisionTreeExperiment.h"
3
+ #include "MachineLearning/DecisionTree/SplitDefinition.h"
4
+ #include "MachineLearning/DecisionTree/DecisionTreeNode.h"
5
+ #include "MachineLearning/DecisionTree/CategoryInfo.h"
6
+ #include "MachineLearning/MLData/MLData.h"
7
+ #include "MachineLearning/GBM/GBMEstimator.h"
8
+ #include "utils/Utils.h"
9
+ #include "utils/StochasticUtils.h"
10
+
11
+ #include <boost/foreach.hpp>
12
+ #include <boost/lexical_cast.hpp>
13
+ #include <cmath>
14
+ using boost::lexical_cast;
15
+
16
+ NodeSplitter::NodeSplitter(MLData* data, int minObservations, double scale)
17
+ : m_data(data), m_minObservations(minObservations),m_scale(scale)
18
+ {
19
+ m_missingValueDefined = m_data->missingValueDefined();
20
+ if (m_missingValueDefined)
21
+ m_missingValue = m_data->getMissingValue();
22
+ }
23
+
24
+ NodeSplitter::~NodeSplitter() {}
25
+
26
+ double NodeSplitter::calculateImprovement(double lhsSumW, double lhsSumZ, double rhsSumW, double rhsSumZ, double missingSumW, double missingSumZ)
27
+ {
28
+ double improvement = 0.0;
29
+
30
+ if (missingSumW == 0)
31
+ {
32
+ double meanZDifference = lhsSumZ / lhsSumW - rhsSumZ / rhsSumW;
33
+ improvement = lhsSumW * rhsSumW * pow(meanZDifference, 2) / (lhsSumW + rhsSumW);
34
+ }
35
+ else
36
+ {
37
+ double meanLRDifference = lhsSumZ / lhsSumW - rhsSumZ / rhsSumW;
38
+ double meanLMDifference = lhsSumZ / lhsSumW - missingSumZ / missingSumW;
39
+ double meanRMDifference = rhsSumZ / rhsSumW - missingSumZ / missingSumW;
40
+
41
+ improvement += lhsSumW * rhsSumW * pow(meanLRDifference, 2);
42
+ improvement += lhsSumW * missingSumW * pow(meanLMDifference, 2);
43
+ improvement += rhsSumW * missingSumW * pow(meanRMDifference, 2);
44
+ improvement /= (lhsSumW + rhsSumW + missingSumW);
45
+ }
46
+
47
+ return improvement;
48
+ }
49
+
50
+ // double NodeSplitter::calculateImprovement(double lhsSumW, double lhsSumZ, double rhsSumW, double rhsSumZ, double missingSumW, double missingSumZ)
51
+ // {
52
+ // double improvement = 0.0;
53
+ //
54
+ // if (lhsSumW)
55
+ // improvement += pow(lhsSumZ, 2) / lhsSumW;
56
+ //
57
+ // if (rhsSumW)
58
+ // improvement += pow(rhsSumZ, 2) / rhsSumW;
59
+ //
60
+ // if (missingSumW)
61
+ // improvement += pow(missingSumZ, 2) / missingSumW;
62
+ //
63
+ // return improvement;
64
+ // }
65
+
66
+ shared_ptr<DecisionTreeNode> NodeSplitter::createLhsChild( shared_ptr<SplitDefinition> splitDefinition )
67
+ {
68
+ return createChild(splitDefinition, LHS);
69
+ }
70
+
71
+ shared_ptr<DecisionTreeNode> NodeSplitter::createRhsChild( shared_ptr<SplitDefinition> splitDefinition )
72
+ {
73
+ return createChild(splitDefinition, RHS);
74
+ }
75
+
76
+ shared_ptr<DecisionTreeNode> NodeSplitter::createMissingChild( shared_ptr<SplitDefinition> splitDefinition )
77
+ {
78
+ return createChild(splitDefinition, MISSING);
79
+ }
80
+
81
+ shared_ptr<DecisionTreeNode> NodeSplitter::createChild( shared_ptr<SplitDefinition> splitDefinition, Partition partition )
82
+ {
83
+ vector<shared_ptr<DecisionTreeExperiment> > experiments = splitDefinition->getNodeToSplit()->getExperiments();
84
+ vector<shared_ptr<DecisionTreeExperiment> > childExperiments =
85
+ partitionExperiments(experiments, splitDefinition, partition);
86
+
87
+ double sumZ;
88
+ double sumW;
89
+ if (partition == LHS)
90
+ {
91
+ sumZ = splitDefinition->getLhsSumZ();
92
+ sumW = splitDefinition->getLhsSumW();
93
+ }
94
+ else if (partition == RHS)
95
+ {
96
+ sumZ = splitDefinition->getRhsSumZ();
97
+ sumW = splitDefinition->getRhsSumW();
98
+ }
99
+ else
100
+ {
101
+ sumZ = splitDefinition->getMissingSumZ();
102
+ sumW = splitDefinition->getMissingSumW();
103
+ }
104
+ shared_ptr<DecisionTreeNode> child =
105
+ shared_ptr<DecisionTreeNode>(new DecisionTreeNode(childExperiments, sumZ, sumW, partition, splitDefinition));
106
+ return child;
107
+ }
108
+
109
+ vector<shared_ptr<DecisionTreeExperiment> > NodeSplitter::partitionExperiments(vector<shared_ptr<DecisionTreeExperiment> >& experiments,
110
+ shared_ptr<SplitDefinition> splitDefinition, Partition partition)
111
+ {
112
+ bool rhs = !partition;
113
+ vector<shared_ptr<DecisionTreeExperiment> > partitionExperiments;
114
+
115
+ if (partition == LHS)
116
+ partitionExperiments.reserve(splitDefinition->getLhsExperimentCount());
117
+ else if (partition == RHS)
118
+ partitionExperiments.reserve(splitDefinition->getRhsExperimentCount());
119
+ else if (partition == MISSING)
120
+ partitionExperiments.reserve(splitDefinition->getMissingExperimentCount());
121
+
122
+ int featureIndex = splitDefinition->getFeatureIndex();
123
+
124
+ if (splitDefinition->isCategorical())
125
+ {
126
+ // categorical
127
+ set<double>& lhsCategories = splitDefinition->getLhsCategories();
128
+ set<double>& rhsCategories = splitDefinition->getRhsCategories();
129
+
130
+ BOOST_FOREACH(shared_ptr<DecisionTreeExperiment>& experiment, experiments)
131
+ {
132
+ double featureValue = experiment->getFeatureValue(featureIndex);
133
+ set<double>::const_iterator lhsIt = lhsCategories.find(featureValue);
134
+ set<double>::const_iterator rhsIt = rhsCategories.find(featureValue);
135
+
136
+ if ((partition == MISSING && m_missingValueDefined && m_missingValue == featureValue) ||
137
+ (partition == LHS && lhsIt != lhsCategories.end()) ||
138
+ (partition == RHS && rhsIt != rhsCategories.end()))
139
+ {
140
+ partitionExperiments.push_back(experiment);
141
+ }
142
+ }
143
+ }
144
+ else
145
+ {
146
+ // continuous
147
+ double splitValue = splitDefinition->getSplitValue();
148
+ BOOST_FOREACH(shared_ptr<DecisionTreeExperiment>& experiment, experiments)
149
+ {
150
+ double featureValue = experiment->getFeatureValue(featureIndex);
151
+
152
+ if (m_missingValueDefined && featureValue == m_missingValue)
153
+ {
154
+ // experiment has a missing value
155
+ if (partition == MISSING)
156
+ partitionExperiments.push_back(experiment);
157
+ }
158
+ else if ((partition == LHS && featureValue < splitValue) ||
159
+ (partition == RHS && featureValue > splitValue))
160
+ partitionExperiments.push_back(experiment);
161
+ }
162
+ }
163
+ return partitionExperiments;
164
+ }
165
+
166
+ vector<shared_ptr<DecisionTreeNode> > NodeSplitter::splitNode( shared_ptr<DecisionTreeNode> nodeToSplit, vector<int> featuresToConsider )
167
+ {
168
+ vector<shared_ptr<DecisionTreeNode> > children;
169
+
170
+ if (nodeToSplit->getSumW() == 0)
171
+ return children;
172
+
173
+ // find terminal node with best improvement for any of those variables
174
+ shared_ptr<SplitDefinition> bestSplit;
175
+ double bestImprovement = 0.0;
176
+
177
+ vector<double> vecImprovements;
178
+ vector<shared_ptr<SplitDefinition> > vecSplits;
179
+
180
+ vecImprovements.reserve(featuresToConsider.size());
181
+ vecSplits.reserve(featuresToConsider.size());
182
+
183
+ set<int>& categoricalFeatures = m_data->getCategoricalFeatureIndices();
184
+
185
+ BOOST_FOREACH(int featureIndex, featuresToConsider)
186
+ {
187
+
188
+ shared_ptr<SplitDefinition> split;
189
+
190
+ if (Utils::hasElement(categoricalFeatures,featureIndex))
191
+ split = createCategoricalSplitDefinition(nodeToSplit, featureIndex);
192
+ else
193
+ split = createContinuousSplitDefinition(nodeToSplit, featureIndex);
194
+
195
+ vecSplits.push_back(split);
196
+ vecImprovements.push_back(split.get() ? split->getImprovement() : 0);
197
+
198
+ if (!split.get()) // it returned an invalid
199
+ continue;
200
+
201
+ if (split->getImprovement() > bestImprovement)
202
+ {
203
+ bestImprovement = split->getImprovement();
204
+ bestSplit = split;
205
+ }
206
+ }
207
+ if (bestImprovement == 0.0)
208
+ return children;
209
+
210
+ if (m_scale != std::numeric_limits<double>::infinity() && vecImprovements.size() > 1)
211
+ {
212
+ vector<float> exp_u;
213
+ BOOST_FOREACH(double improvement, vecImprovements)
214
+ exp_u.push_back(m_scale * improvement / bestImprovement);
215
+
216
+ vector<float> pdf = StochasticUtils::convertHistogramToPdf(exp_u);
217
+ int bestIndex = StochasticUtils::chooseCategoryFromPdf(pdf);
218
+ bestImprovement = vecImprovements.at(bestIndex);
219
+ bestSplit = vecSplits.at(bestIndex);
220
+ }
221
+
222
+ int featureIndex = bestSplit->getFeatureIndex();
223
+ bool isCategorical = Utils::hasElement(categoricalFeatures,featureIndex);
224
+
225
+ shared_ptr<DecisionTreeNode> lhsChild = createLhsChild(bestSplit);
226
+ shared_ptr<DecisionTreeNode> rhsChild = createRhsChild(bestSplit);
227
+ shared_ptr<DecisionTreeNode> missingChild = createMissingChild(bestSplit);
228
+
229
+ nodeToSplit->defineSplit(bestSplit,lhsChild,rhsChild,missingChild);
230
+
231
+ // if (m_parameters.verbose)
232
+ // vlcMessage.Write("Split at feature index " + ToString(bestSplit->getFeatureIndex()) + " at value " + ToString(bestSplit->getSplitValue()) + " with improvement " + ToString(bestSplit->getImprovement()));
233
+
234
+ // finally, remove the node we just split from the terminal nodes, and add the children
235
+ children.push_back(lhsChild);
236
+ children.push_back(rhsChild);
237
+ children.push_back(missingChild);
238
+
239
+ return children;
240
+ }
241
+
242
// Comparator that orders experiments by one feature's value.  The feature to
// sort on must be assigned to `featureIndexToSort` before use; -1 means
// "unset" and makes operator() throw.
//
// NOTE(review): `featureSorter` below is a mutable file-scope global written by
// createContinuousSplitDefinition before each sort — unsynchronized shared
// state, so concurrent splitting would race; confirm single-threaded use.
struct FeatureSorter
{
    FeatureSorter()
    : featureIndexToSort(-1)
    {}

    // Index of the feature whose values drive the comparison; -1 = unset.
    int featureIndexToSort;

    // Strict-weak ordering: ascending by the selected feature's value.
    bool operator() (shared_ptr<DecisionTreeExperiment> a, shared_ptr<DecisionTreeExperiment> b)
    {
        if (featureIndexToSort == -1)
            throw std::runtime_error("SortOnFeature object doesn't know which feature to sort on!");

        return a->getFeatureValue(featureIndexToSort) < b->getFeatureValue(featureIndexToSort);
    }
} featureSorter;
258
+
259
// Find the best threshold split of `node` on a continuous feature.
//
// The node's experiments are sorted by the feature value; a single sweep moves
// experiments from the RHS accumulators into the LHS (or MISSING) accumulators,
// evaluating the improvement at every boundary between distinct feature values
// that satisfies m_minObservations on both sides.  If m_scale is finite, the
// accepted boundary is sampled softmax-style from all candidates rather than
// taking the argmax.
//
// Returns a SplitDefinition for the chosen boundary, or a null shared_ptr if
// no valid boundary exists.
shared_ptr<SplitDefinition> NodeSplitter::createContinuousSplitDefinition( shared_ptr<DecisionTreeNode> node, int featureIndex )
{
    // Copy so the sort below does not disturb the node's own ordering.
    vector<shared_ptr<DecisionTreeExperiment> > sortedExperiments = node->getExperiments();
    // vector<shared_ptr<DecisionTreeExperiment> >& sortedExperiments = node->getSortedExperimentsForFeature(featureIndex);

    // Configure the file-scope comparator, then sort ascending by this feature.
    featureSorter.featureIndexToSort = featureIndex;
    sort(sortedExperiments.begin(), sortedExperiments.end(), featureSorter);

    // Running accumulators for the sweep: everything starts on the RHS.
    double rhsSumZ = 0, rhsSumW = 0, lhsSumZ = 0, lhsSumW = 0;
    double missingSumZ = 0, missingSumW = 0;

    // Parallel arrays recording every valid candidate boundary, so a
    // stochastic choice can be made after the sweep.
    vector<double> vecLhsSumZ;
    vector<double> vecLhsSumW;
    vector<int> vecLhsCount;
    vector<double> vecRhsSumZ;
    vector<double> vecRhsSumW;
    vector<int> vecRhsCount;
    vector<double> vecMissingSumZ;
    vector<double> vecMissingSumW;
    vector<int> vecMissingCount;
    vector<double> vecImprovement;
    vector<int> vecPosition;

    // Filled from the chosen candidate index after the sweep.
    double bestLhsSumZ;
    double bestLhsSumW;
    int bestLhsCount;
    double bestRhsSumZ;
    double bestRhsSumW;
    int bestRhsCount;
    double bestMissingSumZ;
    double bestMissingSumW;
    int bestMissingCount;

    double bestImprovement = 0.0;
    int bestPosition = -1;   // index into sortedExperiments of the last LHS element
    int bestIndex = -1;      // index into the parallel candidate arrays

    int lhsCount = 0, missingCount = 0;
    int rhsCount = (int) sortedExperiments.size();

    rhsSumZ = node->getSumZ();
    rhsSumW = node->getSumW();
    int position = -1;
    double previousFeatureValue = 0;

    BOOST_FOREACH(shared_ptr<DecisionTreeExperiment> experiment, sortedExperiments)
    {
        double featureValue = experiment->getFeatureValue(featureIndex);

        if (featureValue != previousFeatureValue)
        {
//             vlcMessage.Write("featureValue != previousFeatureValue => " + ToString(featureValue != previousFeatureValue));
//             vlcMessage.Write("lhsSumW => " + ToString(lhsSumW));
//             vlcMessage.Write("lhsSumZ => " + ToString(lhsSumZ));
//             vlcMessage.Write("rhsSumW => " + ToString(rhsSumW));
//             vlcMessage.Write("rhsSumZ => " + ToString(rhsSumZ));
//             vlcMessage.Write("missingSumW => " + ToString(missingSumW));
//             vlcMessage.Write("missingSumZ => " + ToString(missingSumZ));
//             vlcMessage.Write("improvement => " + ToString(improvement));
//             vlcMessage.Write("bestImprovement => " + ToString(bestImprovement));
//             vlcMessage.Write("m_minObservations => " + ToString(m_minObservations));
        }

        // A boundary is only valid where the feature value actually changes
        // (splitting between equal values would be ambiguous) and both sides
        // meet the minimum observation count.
        if (featureValue != previousFeatureValue &&
            lhsCount >= m_minObservations &&
            rhsCount >= m_minObservations
            )
        {
            double improvement = calculateImprovement(lhsSumW, lhsSumZ, rhsSumW, rhsSumZ, missingSumW, missingSumZ);
            vecPosition.push_back(position);
            vecImprovement.push_back(improvement);
            vecLhsSumZ.push_back(lhsSumZ);
            vecLhsSumW.push_back(lhsSumW);
            vecLhsCount.push_back(lhsCount);
            vecRhsSumZ.push_back(rhsSumZ);
            vecRhsSumW.push_back(rhsSumW);
            vecRhsCount.push_back(rhsCount);
            vecMissingSumZ.push_back(missingSumZ);
            vecMissingSumW.push_back(missingSumW);
            vecMissingCount.push_back(missingCount);

            if (improvement > bestImprovement)
            {
                bestImprovement = improvement;
                bestPosition = position;
                bestIndex = (int) vecPosition.size() - 1;
            }
//             if (improvement > bestImprovement)
//             {
//                 bestImprovement = improvement;
//                 bestPosition = position;
//                 bestLhsSumZ = lhsSumZ;
//                 bestLhsSumW = lhsSumW;
//                 bestLhsCount = lhsCount;
//                 bestRhsSumZ = rhsSumZ;
//                 bestRhsSumW = rhsSumW;
//                 bestRhsCount = rhsCount;
//                 bestMissingSumZ = missingSumZ;
//                 bestMissingSumW = missingSumW;
//                 bestMissingCount = missingCount;
//             }
//             vlcMessage.Write("improvement => " + ToString(improvement));

        }
        // Move this experiment off the RHS and onto the LHS (or MISSING).
        double weight = experiment->getWeight();
        double z = experiment->getZ();
        rhsSumZ -= weight * z;
        rhsSumW -= weight;
        --rhsCount;

        if (m_missingValueDefined && featureValue == m_missingValue)
        {
            missingSumZ += weight * z;
            missingSumW += weight;
            ++missingCount;
        }
        else
        {
            lhsSumZ += weight * z;
            lhsSumW += weight;
            ++lhsCount;
        }

        previousFeatureValue = featureValue;
        ++position;
    }

    if (bestPosition == -1)
        return shared_ptr<SplitDefinition>();

    // Finite m_scale: sample a candidate with softmax weights
    // exp(m_scale * improvement / bestImprovement) instead of the argmax.
    if (m_scale != std::numeric_limits<double>::infinity() && vecImprovement.size() > 1)
    {
        vector<float> exp_u;
        exp_u.reserve(vecImprovement.size());
        BOOST_FOREACH(double& improvement, vecImprovement)
        {
            exp_u.push_back(exp(m_scale * improvement / bestImprovement));
        }
        vector<float> pdf = StochasticUtils::convertHistogramToPdf(exp_u);
        bestIndex = StochasticUtils::chooseCategoryFromPdf(pdf, "improvements");
    }

    // Rehydrate the chosen candidate's statistics from the parallel arrays.
    bestLhsSumZ = vecLhsSumZ.at(bestIndex);
    bestLhsSumW = vecLhsSumW.at(bestIndex);
    bestLhsCount = vecLhsCount.at(bestIndex);
    bestRhsSumZ = vecRhsSumZ.at(bestIndex);
    bestRhsSumW = vecRhsSumW.at(bestIndex);
    bestRhsCount = vecRhsCount.at(bestIndex);
    bestMissingSumZ = vecMissingSumZ.at(bestIndex);
    bestMissingSumW = vecMissingSumW.at(bestIndex);
    bestMissingCount = vecMissingCount.at(bestIndex);
    bestImprovement = vecImprovement.at(bestIndex);
    bestPosition = vecPosition.at(bestIndex);

    if (bestPosition >= (int) (sortedExperiments.size()-1))
        throw std::runtime_error(string("Unexpected bestPosition: ") + lexical_cast<string>(bestPosition));

    double lhsFeatureValue = sortedExperiments.at(bestPosition)->getFeatureValue(featureIndex);
    double rhsFeatureValue = sortedExperiments.at(bestPosition + 1)->getFeatureValue(featureIndex);

    // Split midway between the neighbouring values; if the LHS boundary value
    // is the missing value itself, use it directly as the sentinel.
    double splitValue;
    if (m_missingValueDefined && (lhsFeatureValue == m_missingValue))
        splitValue = m_missingValue;
    else
        splitValue = 0.5 * (lhsFeatureValue + rhsFeatureValue);


    shared_ptr<SplitDefinition> splitDefinition = shared_ptr<SplitDefinition>
        (new SplitDefinition(node, featureIndex, splitValue, bestLhsSumZ, bestLhsSumW, bestLhsCount,
                             bestRhsSumZ, bestRhsSumW, bestRhsCount, bestMissingSumZ, bestMissingSumW,
                             bestMissingCount, bestImprovement));

    // create SplitDefinition
    return splitDefinition;
}
434
+
435
// Find the best split of `node` on a categorical feature.
//
// Experiments are bucketed per category (missing-valued ones accumulated
// separately), the categories are sorted (by CategoryInfo's own ordering), and
// a sweep moves whole categories from the RHS to the LHS, scoring each cut
// point with calculateImprovement.  The chosen cut partitions the categories
// into LHS / RHS sets.
//
// Returns a SplitDefinition, or a null shared_ptr when there is only one
// category or no valid cut exists (and there is no missing weight).
shared_ptr<SplitDefinition> NodeSplitter::createCategoricalSplitDefinition( shared_ptr<DecisionTreeNode> node, int featureIndex )
{
    vector<shared_ptr<DecisionTreeExperiment> > experiments = node->getExperiments();

    map<double, CategoryInfo> experimentsPerCategory;

    // Bucket experiments by category; missing values go to their own totals.
    double missingSumZ = 0, missingSumW = 0;
    int missingCount = 0;
    BOOST_FOREACH(shared_ptr<DecisionTreeExperiment>& experiment, experiments)
    {
        double featureValue = experiment->getFeatureValue(featureIndex);

        if (m_missingValueDefined && m_missingValue == featureValue)
        {
            double w = experiment->getWeight();
            double z = experiment->getZ();
            missingSumZ += w * z;
            missingSumW += w;
            missingCount++;
        }
        else
        {
            CategoryInfo& info = experimentsPerCategory[featureValue];
            info.category = featureValue;
            info.addExperiment(experiment);
        }
    }

    if (experimentsPerCategory.size() == 1)
        return shared_ptr<SplitDefinition>(); // can't split one thing!

    // put them into a vector to make sorting easier!
    vector<CategoryInfo> categoryInfoVector;
    typedef pair<double, CategoryInfo> ElementType;
    BOOST_FOREACH(ElementType e, experimentsPerCategory)
        categoryInfoVector.push_back(e.second);

    sort(categoryInfoVector.begin(), categoryInfoVector.end());

    // Sweep state: all categories start on the RHS; -1 values are sentinels
    // meaning "no valid cut recorded yet".
    double rhsSumZ = 0, rhsSumW = 0, lhsSumZ = 0, lhsSumW = 0;
    double bestImprovement = 0.0;
    int bestPosition = -1;
    double bestLhsSumZ = -1;
    double bestLhsSumW = -1;
    int bestLhsCount = -1;
    double bestRhsSumZ = -1;
    double bestRhsSumW = -1;
    int bestRhsCount = -1;
    int lhsCount = 0, rhsCount = 0;

    rhsSumZ = node->getSumZ();
    rhsSumW = node->getSumW();
    rhsCount = (int) experiments.size();

    int position = -1;

    // Evaluate the cut BEFORE each category is moved; position is the index of
    // the last category on the LHS for that cut.
    BOOST_FOREACH(CategoryInfo& e, categoryInfoVector)
    {
        double improvement = calculateImprovement(lhsSumW, lhsSumZ, rhsSumW, rhsSumZ, missingSumW, missingSumZ);
        if ( improvement > bestImprovement
            && lhsCount >= m_minObservations
            && rhsCount >= m_minObservations)
        {
            bestImprovement = improvement;
            bestPosition = position;
            bestLhsSumZ = lhsSumZ;
            bestLhsSumW = lhsSumW;
            bestLhsCount = lhsCount;
            bestRhsSumZ = rhsSumZ;
            bestRhsSumW = rhsSumW;
            bestRhsCount = rhsCount;
        }

        ++position;
        rhsSumW -= e.sumW;
        rhsSumZ -= e.sumZ;
        rhsCount -= e.countN;

        lhsSumW += e.sumW;
        lhsSumZ += e.sumZ;
        lhsCount += e.countN;
    }

    // NOTE(review): when bestPosition == -1 but missingSumW != 0 we fall
    // through and build a definition whose best* fields are still the -1
    // sentinels (all categories land on the RHS) — confirm this is intended.
    if (bestPosition == -1 && missingSumW == 0)
        return shared_ptr<SplitDefinition>();

    set<double> lhsCategories;
    set<double> rhsCategories;

    // Categories up to and including bestPosition go left; the rest go right.
    int index = -1;
    BOOST_FOREACH(CategoryInfo& info, categoryInfoVector)
    {
        ++index;
        if (index <= bestPosition)
            lhsCategories.insert(info.category);
        else
            rhsCategories.insert(info.category);
    }

    // we have what we need to create a split definition now
    shared_ptr<SplitDefinition> splitDefinition = shared_ptr<SplitDefinition>
        (new SplitDefinition(node, featureIndex, lhsCategories, rhsCategories, bestLhsSumZ, bestLhsSumW,
        bestLhsCount, bestRhsSumZ, bestRhsSumW, bestRhsCount, missingSumZ, missingSumW, missingCount, bestImprovement));

    return splitDefinition;
}
541
+
542
+ shared_ptr<SplitDefinition> NodeSplitter::createSplitDefinition( shared_ptr<DecisionTreeNode> node, int featureIndex )
543
+ {
544
+ if (Utils::hasElement(m_data->getCategoricalFeatureIndices(),featureIndex))
545
+ return createCategoricalSplitDefinition(node, featureIndex);
546
+ else
547
+ return createContinuousSplitDefinition(node, featureIndex);
548
+ }
549
+
550
+
551
+
@@ -0,0 +1,22 @@
1
+ #include "MachineLearning/DecisionTree/NodeSplitterCategorical.h"
2
+ #include "MachineLearning/DecisionTree/DecisionTreeNode.h"
3
+ #include "MachineLearning/DecisionTree/DecisionTreeExperiment.h"
4
+ #include "MachineLearning/DecisionTree/SplitDefinition.h"
5
+ #include "MachineLearning/DecisionTree/CategoryInfo.h"
6
+
7
+ NodeSplitterCategorical::NodeSplitterCategorical(MLData* data, int minObservations, double scale)
8
+ : NodeSplitter(data, minObservations, scale)
9
+ {
10
+
11
+ }
12
+
13
+ NodeSplitterCategorical::~NodeSplitterCategorical()
14
+ {
15
+
16
+ }
17
+
18
+ shared_ptr<SplitDefinition> NodeSplitterCategorical::createSplitDefinition(shared_ptr<DecisionTreeNode> node,
19
+ int featureIndex)
20
+ {
21
+
22
+ }
@@ -0,0 +1,21 @@
1
+ #include "MachineLearning/DecisionTree/NodeSplitterContinuous.h"
2
+ #include "MachineLearning/DecisionTree/DecisionTreeExperiment.h"
3
+ #include "MachineLearning/DecisionTree/SplitDefinition.h"
4
+ #include "MachineLearning/DecisionTree/DecisionTreeNode.h"
5
+
6
+
7
+
8
+ NodeSplitterContinuous::NodeSplitterContinuous(MLData* data, int minObservations, double scale)
9
+ : NodeSplitter(data, minObservations, scale)
10
+ {
11
+ }
12
+
13
+ NodeSplitterContinuous::~NodeSplitterContinuous()
14
+ {
15
+
16
+ }
17
+
18
+ shared_ptr<SplitDefinition> NodeSplitterContinuous::createSplitDefinition(shared_ptr<DecisionTreeNode> node, int featureIndex)
19
+ {
20
+
21
+ }