ml4r 0.1.4 → 0.1.5
Sign up to get free protection for your applications and to get access to all the features.
- data/ext/ml4r/LinearRegression/LinearRegression.cpp +305 -0
- data/ext/ml4r/LinearRegression/OLSLinearRegression.cpp +75 -0
- data/ext/ml4r/MachineLearning/DecisionTree/DecisionTreeExperiment.cpp +50 -0
- data/ext/ml4r/MachineLearning/DecisionTree/DecisionTreeNode.cpp +195 -0
- data/ext/ml4r/MachineLearning/DecisionTree/NodeSplitter.cpp +551 -0
- data/ext/ml4r/MachineLearning/DecisionTree/NodeSplitterCategorical.cpp +22 -0
- data/ext/ml4r/MachineLearning/DecisionTree/NodeSplitterContinuous.cpp +21 -0
- data/ext/ml4r/MachineLearning/DecisionTree/SplitDefinition.cpp +142 -0
- data/ext/ml4r/MachineLearning/GBM/BernoulliCalculator.cpp +95 -0
- data/ext/ml4r/MachineLearning/GBM/GBMEstimator.cpp +601 -0
- data/ext/ml4r/MachineLearning/GBM/GBMOutput.cpp +86 -0
- data/ext/ml4r/MachineLearning/GBM/GBMRunner.cpp +117 -0
- data/ext/ml4r/MachineLearning/GBM/GaussianCalculator.cpp +94 -0
- data/ext/ml4r/MachineLearning/GBM/ZenithGBM.cpp +317 -0
- data/ext/ml4r/MachineLearning/MLData/MLData.cpp +232 -0
- data/ext/ml4r/MachineLearning/MLData/MLDataFields.cpp +1 -0
- data/ext/ml4r/MachineLearning/MLData/MLDataReader.cpp +139 -0
- data/ext/ml4r/MachineLearning/MLData/ZenithMLData.cpp +96 -0
- data/ext/ml4r/MachineLearning/MLData/ZenithMLDataReader.cpp +113 -0
- data/ext/ml4r/MachineLearning/MLExperiment.cpp +69 -0
- data/ext/ml4r/MachineLearning/MLRunner.cpp +183 -0
- data/ext/ml4r/MachineLearning/MLUtils.cpp +15 -0
- data/ext/ml4r/MachineLearning/RandomForest/RandomForestEstimator.cpp +172 -0
- data/ext/ml4r/MachineLearning/RandomForest/RandomForestOutput.cpp +66 -0
- data/ext/ml4r/MachineLearning/RandomForest/RandomForestRunner.cpp +84 -0
- data/ext/ml4r/MachineLearning/RandomForest/ZenithRandomForest.cpp +184 -0
- data/ext/ml4r/ml4r.cpp +34 -0
- data/ext/ml4r/ml4r_wrap.cpp +15727 -0
- data/ext/ml4r/utils/MathUtils.cpp +204 -0
- data/ext/ml4r/utils/StochasticUtils.cpp +73 -0
- data/ext/ml4r/utils/Utils.cpp +14 -0
- data/ext/ml4r/utils/VlcMessage.cpp +3 -0
- metadata +33 -1
@@ -0,0 +1,117 @@
|
|
1
|
+
|
2
|
+
#include "MachineLearning/GBM/GBMRunner.h"
|
3
|
+
#include "MachineLearning/GBM/GBMEstimator.h"
|
4
|
+
#include "MachineLearning/GBM/GBMOutput.h"
|
5
|
+
#include "MachineLearning/GBM/BernoulliCalculator.h"
|
6
|
+
#include "MachineLearning/GBM/GaussianCalculator.h"
|
7
|
+
#include "MachineLearning/MLData/MLData.h"
|
8
|
+
#include "MachineLearning/DecisionTree/DecisionTreeExperiment.h"
|
9
|
+
#include "MachineLearning/DecisionTree/DecisionTreeNode.h"
|
10
|
+
#include "MachineLearning/DecisionTree/FeatureInteraction.h"
|
11
|
+
|
12
|
+
#include "utils/VlcMessage.h"
|
13
|
+
|
14
|
+
// #ifdef TBB_USE_THREADING_TOOLS
|
15
|
+
// #undef TBB_USE_THREADING_TOOLS
|
16
|
+
// #endif
|
17
|
+
// #define TBB_USE_THREADING_TOOLS 1
|
18
|
+
// #include "tbb/task_scheduler_init.h"
|
19
|
+
// #include "tbb/parallel_for.h"
|
20
|
+
// #include "tbb/blocked_range.h"
|
21
|
+
// #include "tbb/explicit_range.h"
|
22
|
+
|
23
|
+
#include <math.h>
|
24
|
+
#include <boost/pointer_cast.hpp>
|
25
|
+
#include <boost/make_shared.hpp>
|
26
|
+
#include <boost/foreach.hpp>
|
27
|
+
using boost::make_shared;
|
28
|
+
using boost::dynamic_pointer_cast;
|
29
|
+
|
30
|
+
|
31
|
+
// Creates a runner whose `parameters` member points at a freshly
// default-constructed GBMParameters object.
GBMRunner::GBMRunner()
{
    parameters = boost::make_shared<GBMParameters>();
}
|
35
|
+
|
36
|
+
// Nothing to release explicitly here; the destructor body is intentionally empty.
GBMRunner::~GBMRunner() {}
|
40
|
+
|
41
|
+
void GBMRunner::config()
|
42
|
+
{
|
43
|
+
|
44
|
+
vector<string>& dataFeatures = m_data->getFeatures();
|
45
|
+
|
46
|
+
// parameters->loadedFeatures = dataFeatures;
|
47
|
+
if (parameters->featuresToRun.empty())
|
48
|
+
parameters->featuresToRun = dataFeatures;
|
49
|
+
else
|
50
|
+
{
|
51
|
+
BOOST_FOREACH(string feature, parameters->featuresToRun)
|
52
|
+
{
|
53
|
+
if (!Utils::hasElement(dataFeatures, feature))
|
54
|
+
throw std::runtime_error("Feature '" + feature + "' specified as part of parameter 'featuresToRun', but feature not found in data");
|
55
|
+
}
|
56
|
+
}
|
57
|
+
if (parameters->featuresToRun.empty())
|
58
|
+
throw std::runtime_error("There are no features to run!");
|
59
|
+
|
60
|
+
if (m_data->missingValueDefined())
|
61
|
+
DecisionTreeNode::setMissingValue(m_data->getMissingValue());
|
62
|
+
|
63
|
+
}
|
64
|
+
|
65
|
+
void GBMRunner::estimateMore(int numTrees)
|
66
|
+
{
|
67
|
+
int numFolds = m_data->getNumFolds();
|
68
|
+
int numThreads = numFolds; // TODO: change this!
|
69
|
+
|
70
|
+
// tbb::task_scheduler_init init(numFolds);
|
71
|
+
// static tbb::simple_partitioner sp;
|
72
|
+
|
73
|
+
int grainSize = numFolds / numThreads;
|
74
|
+
|
75
|
+
// tbb::parallel_for(explicit_range<size_t>(0, numFolds, grainSize),
|
76
|
+
// [&](const explicit_range<size_t>& r) {
|
77
|
+
// int threadNumber = r.begin() / grainSize;
|
78
|
+
// for(size_t foldIndex=r.begin(); foldIndex!=r.end(); ++foldIndex)
|
79
|
+
for (int foldIndex = 0; foldIndex < numFolds; ++foldIndex)
|
80
|
+
{
|
81
|
+
vlcMessage.Begin("Estimating more...");
|
82
|
+
|
83
|
+
shared_ptr<GBMEstimator> estimator = dynamic_pointer_cast<GBMEstimator>(m_estimators.at(foldIndex));
|
84
|
+
estimator->estimateMore(numTrees);
|
85
|
+
|
86
|
+
vlcMessage.End();
|
87
|
+
}
|
88
|
+
// }, sp);
|
89
|
+
}
|
90
|
+
|
91
|
+
void GBMRunner::capTrees( int numTrees )
|
92
|
+
{
|
93
|
+
BOOST_FOREACH(shared_ptr<MLOutput>& output, m_outputObjects)
|
94
|
+
{
|
95
|
+
shared_ptr<GBMOutput> gbmOutput = dynamic_pointer_cast<GBMOutput>(output);
|
96
|
+
gbmOutput->capTrees(numTrees);
|
97
|
+
}
|
98
|
+
}
|
99
|
+
|
100
|
+
// Runs config(), then builds a temporary GBMEstimator over all experiments in
// m_data and returns the result of its findInteractions(howMany).
vector<FeatureInteraction> GBMRunner::getFeatureInteractions( int howMany )
{
    config();

    GBMEstimator interactionFinder(m_data, m_data->getExperiments(), parameters);
    vector<FeatureInteraction> interactions = interactionFinder.findInteractions(howMany);
    return interactions;
}
|
106
|
+
|
107
|
+
// Allocates a GBMEstimator for the given data and training experiments, using
// this runner's `parameters`, and returns it as a shared_ptr<MLEstimator>.
shared_ptr<MLEstimator> GBMRunner::createEstimator(MLData* data, vector<shared_ptr<MLExperiment> > trainingExperiments)
{
    shared_ptr<GBMEstimator> gbmEstimator(new GBMEstimator(data, trainingExperiments, parameters));
    return shared_ptr<MLEstimator>(gbmEstimator);
}
|
111
|
+
|
112
|
+
|
113
|
+
|
114
|
+
|
115
|
+
|
116
|
+
|
117
|
+
|
@@ -0,0 +1,94 @@
|
|
1
|
+
#include "MachineLearning/GBM/GaussianCalculator.h"
|
2
|
+
#include "MachineLearning/DecisionTree/DecisionTreeExperiment.h"
|
3
|
+
|
4
|
+
#include <boost/foreach.hpp>
|
5
|
+
|
6
|
+
// The constructor body is intentionally empty; no setup is done here.
GaussianCalculator::GaussianCalculator() {}
|
10
|
+
|
11
|
+
// The destructor body is intentionally empty; nothing is released here.
GaussianCalculator::~GaussianCalculator() {}
|
15
|
+
|
16
|
+
// Returns the weighted mean squared error of the experiments:
//   sum(w * (y - prediction)^2) / sum(w)
//
// When the total weight is zero (no experiments, or every weight is zero) the
// result is 0.0 rather than the NaN produced by 0/0, matching the zero-weight
// handling in computeFIncrement.
double GaussianCalculator::calculateDeviance(vector<shared_ptr<DecisionTreeExperiment> >& experiments)
{
    double sumSquaredErrors = 0.0;
    double sumWeight = 0.0;
    BOOST_FOREACH(shared_ptr<DecisionTreeExperiment>& experiment, experiments)
    {
        const double weight = experiment->getWeight();
        const double error = experiment->getY() - experiment->getPrediction();
        sumSquaredErrors += weight * error * error;
        sumWeight += weight;
    }

    if (sumWeight == 0)
        return 0.0;

    return sumSquaredErrors / sumWeight;
}
|
28
|
+
|
29
|
+
// Seeds each experiment's F value from its prediction.
//
// If useInitialPredictions is false, every experiment's prediction is first
// overwritten with the weighted mean of Y over all experiments. If it is true,
// the predictions already stored on the experiments are used as-is.
// In both cases F is then set to calculateF(prediction).
//
// When the total weight is zero the mean falls back to 0.0; previously the
// 0/0 division wrote NaN into every prediction and F.
void GaussianCalculator::populateInitialF(vector<shared_ptr<DecisionTreeExperiment> >& experiments, bool useInitialPredictions)
{
    if (!useInitialPredictions)
    {
        // Weighted mean of the response.
        double sumY = 0.0;
        double sumW = 0.0;
        BOOST_FOREACH(shared_ptr<DecisionTreeExperiment>& experiment, experiments)
        {
            const double weight = experiment->getWeight();
            sumY += weight * experiment->getY();
            sumW += weight;
        }
        const double meanY = (sumW == 0) ? 0.0 : sumY / sumW;

        BOOST_FOREACH(shared_ptr<DecisionTreeExperiment>& experiment, experiments)
        {
            experiment->setPrediction(meanY);
        }
    }

    BOOST_FOREACH(shared_ptr<DecisionTreeExperiment>& experiment, experiments)
    {
        experiment->setF(calculateF(experiment->getPrediction()));
    }
}
|
53
|
+
|
54
|
+
// Stores the residual (Y minus the current prediction) as Z on every experiment.
void GaussianCalculator::updateZ(vector<shared_ptr<DecisionTreeExperiment> >& experiments)
{
    typedef vector<shared_ptr<DecisionTreeExperiment> >::iterator ExperimentIterator;

    for (ExperimentIterator it = experiments.begin(); it != experiments.end(); ++it)
    {
        (*it)->setZ((*it)->getY() - (*it)->getPrediction());
    }
}
|
62
|
+
|
63
|
+
// Returns the weighted mean of Z over the given experiments,
// sum(w * z) / sum(w), or 0.0 when the total weight is zero.
double GaussianCalculator::computeFIncrement(vector<shared_ptr<DecisionTreeExperiment> >& experiments)
{
    typedef vector<shared_ptr<DecisionTreeExperiment> >::iterator ExperimentIterator;

    double weightedZTotal = 0.0;
    double weightTotal = 0.0;
    for (ExperimentIterator it = experiments.begin(); it != experiments.end(); ++it)
    {
        weightedZTotal += (*it)->getWeight() * (*it)->getZ();
        weightTotal += (*it)->getWeight();
    }

    // Guard the division: with no weight there is nothing to average.
    return (weightTotal == 0) ? 0.0 : weightedZTotal / weightTotal;
}
|
77
|
+
|
78
|
+
// Recomputes every experiment's prediction from its current F value via
// calculatePrediction.
void GaussianCalculator::updatePredictions(vector<shared_ptr<DecisionTreeExperiment> >& experiments)
{
    typedef vector<shared_ptr<DecisionTreeExperiment> >::iterator ExperimentIterator;

    for (ExperimentIterator it = experiments.begin(); it != experiments.end(); ++it)
    {
        const double refreshed = calculatePrediction((*it)->getF());
        (*it)->setPrediction(refreshed);
    }
}
|
85
|
+
|
86
|
+
// Maps an F value to a prediction. In this calculator the mapping is the
// identity: the value is returned unchanged.
double GaussianCalculator::calculatePrediction(double f)
{
    const double prediction = f;
    return prediction;
}
|
90
|
+
|
91
|
+
// Maps a prediction to an F value. In this calculator the mapping is the
// identity: the value is returned unchanged.
double GaussianCalculator::calculateF(double prediction)
{
    const double f = prediction;
    return f;
}
|
@@ -0,0 +1,317 @@
|
|
1
|
+
// #include "MachineLearning/GBM/ZenithGBM.h"
|
2
|
+
// #include "MachineLearning/GBM/GBMRunner.h"
|
3
|
+
// #include "MachineLearning/DecisionTree/FeatureInteraction.h"
|
4
|
+
// #include "MachineLearning/DecisionTree/SplitDefinition.h"
|
5
|
+
// #include "MachineLearning/MLData/MLData.h"
|
6
|
+
// #include "MachineLearning/gbm/GBMParameters.h"
|
7
|
+
|
8
|
+
// #include "stringConversion.h"
|
9
|
+
// #include "RubyUtils.h"
|
10
|
+
// using namespace RubyUtils;
|
11
|
+
|
12
|
+
|
13
|
+
|
14
|
+
// void zenith_gbm_Free(void* v)
|
15
|
+
// {
|
16
|
+
// delete (reinterpret_cast<GBMRunner*>(v));
|
17
|
+
// }
|
18
|
+
|
19
|
+
// OtInterface::VALUE zenith_gbm_New(int argc, VALUE* argv, VALUE klass)
|
20
|
+
// {
|
21
|
+
// VALUE obj = otRuby->DataWrapStruct(klass, 0, zenith_gbm_Free, 0);
|
22
|
+
// otRuby->rb_obj_call_init(obj, argc, argv);
|
23
|
+
// return obj;
|
24
|
+
// }
|
25
|
+
|
26
|
+
// OtInterface::VALUE zenith_gbm_Initialize(VALUE self)
|
27
|
+
// {
|
28
|
+
// if (otRuby->GetDataPtr(self)) zenith_gbm_Free(otRuby->GetDataPtr(self));
|
29
|
+
// otRuby->SetDataPtr(self, NULL);
|
30
|
+
|
31
|
+
// GBMRunner* gbm = new GBMRunner();
|
32
|
+
// if (gbm == NULL) otRuby->rb_sys_fail("ZenithGBM class could not be created");
|
33
|
+
// otRuby->SetDataPtr(self, gbm);
|
34
|
+
// return self;
|
35
|
+
// }
|
36
|
+
|
37
|
+
// OtInterface::VALUE zenith_gbm_estimate(VALUE self)
|
38
|
+
// {
|
39
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
40
|
+
// try
|
41
|
+
// {
|
42
|
+
// gbm->execute();
|
43
|
+
// }
|
44
|
+
// catch (std::exception e)
|
45
|
+
// {
|
46
|
+
// vlcMessage.Raise((string("Caught error: ") + e.what()).c_str());
|
47
|
+
// }
|
48
|
+
|
49
|
+
// return TOtRubyInterface::Qnil;
|
50
|
+
// }
|
51
|
+
|
52
|
+
|
53
|
+
// OtInterface::VALUE zenith_gbm_estimateMore(VALUE self, VALUE numTrees)
|
54
|
+
// {
|
55
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
56
|
+
// try
|
57
|
+
// {
|
58
|
+
// gbm->estimateMore(RubyUtils::fromValue<int>(numTrees));
|
59
|
+
// }
|
60
|
+
// catch (std::exception e)
|
61
|
+
// {
|
62
|
+
// vlcMessage.Raise((string("Caught error: ") + e.what()).c_str());
|
63
|
+
// }
|
64
|
+
|
65
|
+
// return TOtRubyInterface::Qnil;
|
66
|
+
// }
|
67
|
+
|
68
|
+
|
69
|
+
// OtInterface::VALUE zenith_gbm_setFeaturesToRun(VALUE self, VALUE featuresValue)
|
70
|
+
// {
|
71
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
72
|
+
// gbm->parameters->featuresToRun = RubyUtils::fromValue<vector<string> >(featuresValue);
|
73
|
+
// return TOtRubyInterface::Qnil;
|
74
|
+
// }
|
75
|
+
|
76
|
+
// OtInterface::VALUE zenith_gbm_setData(VALUE self, VALUE data)
|
77
|
+
// {
|
78
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
79
|
+
// MLData* mlData = (MLData*)otRuby->GetDataPtr(data);
|
80
|
+
// gbm->setData(mlData);
|
81
|
+
// return TOtRubyInterface::Qnil;
|
82
|
+
// }
|
83
|
+
|
84
|
+
// OtInterface::VALUE zenith_gbm_setTryMVariables(VALUE self, VALUE mVariablesValue)
|
85
|
+
// {
|
86
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
87
|
+
// gbm->parameters->tryMVariables = RubyUtils::fromValue<int>(mVariablesValue);
|
88
|
+
// return TOtRubyInterface::Qnil;
|
89
|
+
// }
|
90
|
+
|
91
|
+
// OtInterface::VALUE zenith_gbm_setKTerminalNodes(VALUE self, VALUE kNodesValue)
|
92
|
+
// {
|
93
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
94
|
+
// gbm->parameters->growKDecisionTreeNodes = RubyUtils::fromValue<int>(kNodesValue);
|
95
|
+
// return TOtRubyInterface::Qnil;
|
96
|
+
// }
|
97
|
+
|
98
|
+
// OtInterface::VALUE zenith_gbm_setNumIterations(VALUE self, VALUE numIterationsValue)
|
99
|
+
// {
|
100
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
101
|
+
// gbm->parameters->numIterations = RubyUtils::fromValue<int>(numIterationsValue);
|
102
|
+
// return TOtRubyInterface::Qnil;
|
103
|
+
// }
|
104
|
+
|
105
|
+
// OtInterface::VALUE zenith_gbm_setShrinkageFactor(VALUE self, VALUE shrinkageFactorValue)
|
106
|
+
// {
|
107
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
108
|
+
// gbm->parameters->shrinkageFactor = RubyUtils::fromValue<double>(shrinkageFactorValue);
|
109
|
+
// return TOtRubyInterface::Qnil;
|
110
|
+
// }
|
111
|
+
|
112
|
+
// OtInterface::VALUE zenith_gbm_setBagFraction(VALUE self, VALUE bagFractionValue)
|
113
|
+
// {
|
114
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
115
|
+
// gbm->parameters->bagFraction = RubyUtils::fromValue<double>(bagFractionValue);
|
116
|
+
// return TOtRubyInterface::Qnil;
|
117
|
+
// }
|
118
|
+
|
119
|
+
// OtInterface::VALUE zenith_gbm_setTrainingExperimentIds(VALUE self, VALUE experimentIdsValue)
|
120
|
+
// {
|
121
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
122
|
+
// gbm->parameters->trainingExperimentIds = RubyUtils::fromValue<vector<int> >(experimentIdsValue);
|
123
|
+
// return TOtRubyInterface::Qnil;
|
124
|
+
// }
|
125
|
+
|
126
|
+
// OtInterface::VALUE zenith_gbm_predictions(VALUE self, VALUE newMlData)
|
127
|
+
// {
|
128
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
129
|
+
// MLData* data = (MLData*)otRuby->GetDataPtr(newMlData);
|
130
|
+
|
131
|
+
// vector<double> predictions;
|
132
|
+
|
133
|
+
// try
|
134
|
+
// {
|
135
|
+
// predictions = gbm->getPredictions(data);
|
136
|
+
// }
|
137
|
+
// catch (std::exception e)
|
138
|
+
// {
|
139
|
+
// vlcMessage.Raise((string("Could not get predictions. Error: ") + e.what()).c_str());
|
140
|
+
// }
|
141
|
+
|
142
|
+
// return RubyUtils::toValue(predictions);
|
143
|
+
// }
|
144
|
+
|
145
|
+
// OtInterface::VALUE zenith_gbm_training_predictions(VALUE self)
|
146
|
+
// {
|
147
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
148
|
+
// vector<double> predictions;
|
149
|
+
|
150
|
+
// try
|
151
|
+
// {
|
152
|
+
// predictions = gbm->getMeanTrainingPredictions();
|
153
|
+
// }
|
154
|
+
// catch (std::exception e)
|
155
|
+
// {
|
156
|
+
// vlcMessage.Raise((string("Could not get training predictions. Error: ") + e.what()).c_str());
|
157
|
+
// }
|
158
|
+
|
159
|
+
// return RubyUtils::toValue(predictions);
|
160
|
+
// }
|
161
|
+
|
162
|
+
// OtInterface::VALUE zenith_gbm_crossvalidation_predictions(VALUE self)
|
163
|
+
// {
|
164
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
165
|
+
// vector<double> predictions;
|
166
|
+
|
167
|
+
// try
|
168
|
+
// {
|
169
|
+
// predictions = gbm->getCrossValidationPredictions();
|
170
|
+
// }
|
171
|
+
// catch (std::exception e)
|
172
|
+
// {
|
173
|
+
// vlcMessage.Raise((string("Could not get cross validation predictions. Error: ") + e.what()).c_str());
|
174
|
+
// }
|
175
|
+
|
176
|
+
// return RubyUtils::toValue(predictions);
|
177
|
+
// }
|
178
|
+
|
179
|
+
// OtInterface::VALUE zenith_gbm_minObservations(VALUE self, VALUE minObservations)
|
180
|
+
// {
|
181
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
182
|
+
// gbm->parameters->minObservations = RubyUtils::fromValue<int>(minObservations);
|
183
|
+
// return TOtRubyInterface::Qnil;
|
184
|
+
// }
|
185
|
+
|
186
|
+
// OtInterface::VALUE zenith_gbm_setDistribution(VALUE self, VALUE rb_distribution)
|
187
|
+
// {
|
188
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
189
|
+
// string distribution = stringToLower(RubyUtils::fromValue<string>(rb_distribution));
|
190
|
+
|
191
|
+
// if (distribution == "bernoulli")
|
192
|
+
// gbm->parameters->distribution = BERNOULLI;
|
193
|
+
// else if (distribution == "gaussian")
|
194
|
+
// gbm->parameters->distribution = GAUSSIAN;
|
195
|
+
// else
|
196
|
+
// throw std::invalid_argument("ZenithGBM::distribution = " + distribution);
|
197
|
+
|
198
|
+
// return TOtRubyInterface::Qnil;
|
199
|
+
// }
|
200
|
+
|
201
|
+
// OtInterface::VALUE zenith_gbm_verbose(VALUE self, VALUE rb_verbose)
|
202
|
+
// {
|
203
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
204
|
+
// bool verbose = RubyUtils::fromValue<bool>(rb_verbose);
|
205
|
+
// gbm->parameters->verbose = verbose;
|
206
|
+
// return TOtRubyInterface::Qnil;
|
207
|
+
// }
|
208
|
+
|
209
|
+
// OtInterface::VALUE zenith_gbm_setGreedy( VALUE self, VALUE rb_greedy )
|
210
|
+
// {
|
211
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
212
|
+
// bool greedy = RubyUtils::fromValue<bool>(rb_greedy);
|
213
|
+
// gbm->parameters->greedy = greedy;
|
214
|
+
// return TOtRubyInterface::Qnil;
|
215
|
+
// }
|
216
|
+
|
217
|
+
// OtInterface::VALUE zenith_gbm_setRfToLevel( VALUE self, VALUE rb_rfToLevel )
|
218
|
+
// {
|
219
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
220
|
+
// int rfToLevel = RubyUtils::fromValue<int>(rb_rfToLevel);
|
221
|
+
// gbm->parameters->rfToLevel = rfToLevel;
|
222
|
+
// return TOtRubyInterface::Qnil;
|
223
|
+
// }
|
224
|
+
|
225
|
+
// OtInterface::VALUE zenith_gbm_capTrees( VALUE self, VALUE rb_cap )
|
226
|
+
// {
|
227
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
228
|
+
// int cap = RubyUtils::fromValue<int>(rb_cap);
|
229
|
+
// gbm->capTrees(cap);
|
230
|
+
// return TOtRubyInterface::Qnil;
|
231
|
+
// }
|
232
|
+
|
233
|
+
// OtInterface::VALUE zenith_gbm_setScale( VALUE self, VALUE rb_scale )
|
234
|
+
// {
|
235
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
236
|
+
// double scale = RubyUtils::fromValue<double>(rb_scale);
|
237
|
+
// gbm->parameters->scale = scale;;
|
238
|
+
// return TOtRubyInterface::Qnil;
|
239
|
+
// }
|
240
|
+
|
241
|
+
// OtInterface::VALUE splitValueOrCategories(MLData* data, shared_ptr<SplitDefinition> splitDefinition, Partition partition)
|
242
|
+
// {
|
243
|
+
// VALUE returnValue;
|
244
|
+
// if (splitDefinition->isCategorical())
|
245
|
+
// {
|
246
|
+
// // categories as an array
|
247
|
+
// if (partition == LHS)
|
248
|
+
// returnValue = RubyUtils::toValue(splitDefinition->getLhsCategories());
|
249
|
+
// else if (partition == RHS)
|
250
|
+
// returnValue = RubyUtils::toValue(splitDefinition->getRhsCategories());
|
251
|
+
// else if (partition == MISSING)
|
252
|
+
// {
|
253
|
+
// set<double> setMissingValue;
|
254
|
+
// setMissingValue.insert(data->getMissingValue());
|
255
|
+
// returnValue = RubyUtils::toValue(setMissingValue);
|
256
|
+
// }
|
257
|
+
// else
|
258
|
+
// throw std::runtime_error("Primary partition should be either LHS, RHS or MISSING!");
|
259
|
+
// }
|
260
|
+
// else
|
261
|
+
// {
|
262
|
+
// // split value as a double
|
263
|
+
// returnValue = RubyUtils::toValue(splitDefinition->getSplitValue());
|
264
|
+
// }
|
265
|
+
// return returnValue;
|
266
|
+
// }
|
267
|
+
|
268
|
+
// OtInterface::VALUE zenith_gbm_getFeatureInteractions( VALUE self, VALUE howMany )
|
269
|
+
// {
|
270
|
+
// GBMRunner* gbm = (GBMRunner*)otRuby->GetDataPtr(self);
|
271
|
+
// vector<FeatureInteraction> featureInteractions;
|
272
|
+
|
273
|
+
// try
|
274
|
+
// {
|
275
|
+
// featureInteractions = gbm->getFeatureInteractions(RubyUtils::fromValue<int>(howMany));
|
276
|
+
// }
|
277
|
+
// catch (std::exception e)
|
278
|
+
// {
|
279
|
+
// vlcMessage.Raise((string("Could not get calculate interactions. Error: ") + e.what()).c_str());
|
280
|
+
// }
|
281
|
+
// MLData* data = gbm->getData();
|
282
|
+
// vector<string> featureNames = data->getFeatures();
|
283
|
+
|
284
|
+
// vector<vector<VALUE> > returnVector;
|
285
|
+
// returnVector.reserve(featureInteractions.size());
|
286
|
+
|
287
|
+
// BOOST_FOREACH(auto& interaction, featureInteractions)
|
288
|
+
// {
|
289
|
+
// vector<VALUE> v;
|
290
|
+
// v.reserve(6);
|
291
|
+
|
292
|
+
// // improvement
|
293
|
+
// v.push_back(RubyUtils::toValue(interaction.secondarySplitDefinition->getImprovement()));
|
294
|
+
|
295
|
+
// // primary feature name
|
296
|
+
// v.push_back(RubyUtils::toValue(featureNames.at(interaction.primarySplitDefinition->getFeatureIndex())));
|
297
|
+
|
298
|
+
// // either the split value (as double), or the categories
|
299
|
+
// v.push_back(splitValueOrCategories(data, interaction.primarySplitDefinition, interaction.primaryPartition));
|
300
|
+
|
301
|
+
// // the partition chosen
|
302
|
+
// Partition p = interaction.primaryPartition;
|
303
|
+
// int partition = (p == LHS ? 1 : (p == RHS ? 2 : 3));
|
304
|
+
// v.push_back(RubyUtils::toValue(partition));
|
305
|
+
|
306
|
+
// // second feature name
|
307
|
+
// v.push_back(RubyUtils::toValue(featureNames.at(interaction.secondarySplitDefinition->getFeatureIndex())));
|
308
|
+
|
309
|
+
// // secondary split value / left hand side categories
|
310
|
+
// v.push_back(splitValueOrCategories(data, interaction.secondarySplitDefinition, LHS));
|
311
|
+
|
312
|
+
// // no need for a secondary partition, as all children of the second partition are important
|
313
|
+
// returnVector.push_back(v);
|
314
|
+
// }
|
315
|
+
// return RubyUtils::toValue(returnVector);
|
316
|
+
// }
|
317
|
+
|