gbrl 1.0.0.dev6__tar.gz → 1.0.0.dev7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {gbrl-1.0.0.dev6/gbrl.egg-info → gbrl-1.0.0.dev7}/PKG-INFO +4 -3
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/README.md +3 -2
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/gbrl_wrapper.py +14 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/gbt.py +16 -1
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cpp/gbrl.cpp +22 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cpp/gbrl.h +1 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cpp/gbrl_binding.cpp +4 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cpp/types.cpp +162 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cpp/types.h +2 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7/gbrl.egg-info}/PKG-INFO +4 -3
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/pyproject.toml +1 -1
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/LICENSE +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/MANIFEST.in +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/__init__.py +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/ac_gbrl.py +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/config.py +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cpp/config.h +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cpp/fitter.cpp +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cpp/fitter.h +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cpp/loss.cpp +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cpp/loss.h +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cpp/main.cpp +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cpp/math_ops.cpp +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cpp/math_ops.h +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cpp/node.cpp +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cpp/node.h +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cpp/optimizer.cpp +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cpp/optimizer.h +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cpp/predictor.cpp +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cpp/predictor.h +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cpp/scheduler.cpp +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cpp/scheduler.h +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cpp/split_candidate_generator.cpp +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cpp/split_candidate_generator.h +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cpp/utils.cpp +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cpp/utils.h +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cuda/cuda_fitter.cu +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cuda/cuda_fitter.h +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cuda/cuda_loss.cu +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cuda/cuda_loss.h +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cuda/cuda_predictor.cu +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cuda/cuda_predictor.h +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cuda/cuda_preprocess.cu +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cuda/cuda_preprocess.h +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cuda/cuda_types.cu +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cuda/cuda_types.h +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cuda/cuda_utils.cu +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/src/cuda/cuda_utils.h +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl/utils.py +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl.egg-info/SOURCES.txt +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl.egg-info/dependency_links.txt +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl.egg-info/requires.txt +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/gbrl.egg-info/top_level.txt +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/setup.cfg +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/setup.py +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/tests/test_gbt_multi.py +0 -0
- {gbrl-1.0.0.dev6 → gbrl-1.0.0.dev7}/tests/test_gbt_single.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: gbrl
|
|
3
|
-
Version: 1.0.0.
|
|
3
|
+
Version: 1.0.0.dev7
|
|
4
4
|
Summary: Gradient Boosted Trees for RL
|
|
5
5
|
Author-email: Benjamin Fuhrer <bfuhrer@nvidia.com>, Chen Tessler <ctessler@nvidia.com>, Gal Dalal <galal@nvidia.com>
|
|
6
6
|
Classifier: Development Status :: 4 - Beta
|
|
@@ -33,10 +33,11 @@ GBRL is a Python-based GBT library designed and optimized for reinforcement lear
|
|
|
33
33
|
## Getting started
|
|
34
34
|
|
|
35
35
|
### Dependencies
|
|
36
|
+
#### MAC OS
|
|
37
|
+
```
|
|
36
38
|
llvm
|
|
37
39
|
openmp
|
|
38
|
-
|
|
39
|
-
#### MAC OS
|
|
40
|
+
```
|
|
40
41
|
|
|
41
42
|
Make sure to run:
|
|
42
43
|
```
|
|
@@ -126,6 +126,16 @@ class GBTWrapper:
|
|
|
126
126
|
status = self.model.save(filename)
|
|
127
127
|
assert status == 0, "Failed to save model"
|
|
128
128
|
|
|
129
|
+
def export(self, filename: str, modelname: str = None) -> None:
|
|
130
|
+
# exports model to C
|
|
131
|
+
filename = filename.rstrip('.')
|
|
132
|
+
filename += '.h'
|
|
133
|
+
assert self.model is not None, "Can't export non-existent model!"
|
|
134
|
+
if modelname is None:
|
|
135
|
+
modelname = ""
|
|
136
|
+
status = self.model.export(filename, modelname)
|
|
137
|
+
assert status == 0, "Failed to export model"
|
|
138
|
+
|
|
129
139
|
@classmethod
|
|
130
140
|
def load(cls, filename: str) -> "GBTWrapper":
|
|
131
141
|
filename = filename.rstrip('.')
|
|
@@ -313,6 +323,10 @@ class SeparateActorCriticWrapper:
|
|
|
313
323
|
self.policy_model.save(filename + '_policy')
|
|
314
324
|
self.value_model.save(filename + '_value')
|
|
315
325
|
|
|
326
|
+
def export(self, filename: str) -> None:
|
|
327
|
+
self.policy_model.export(filename + '_policy')
|
|
328
|
+
self.value_model.export(filename + '_value')
|
|
329
|
+
|
|
316
330
|
@classmethod
|
|
317
331
|
def load(cls, filename: str) -> "SeparateActorCriticWrapper":
|
|
318
332
|
instance = cls.__new__(cls)
|
|
@@ -172,9 +172,24 @@ class GradientBoostingTrees:
|
|
|
172
172
|
return self._model.get_num_trees()
|
|
173
173
|
|
|
174
174
|
|
|
175
|
-
def save_model(self, save_path: str
|
|
175
|
+
def save_model(self, save_path: str) -> None:
|
|
176
|
+
"""
|
|
177
|
+
Saves model to file
|
|
178
|
+
|
|
179
|
+
Args:
|
|
180
|
+
filename (str): Absolute path and name of save filename.
|
|
181
|
+
"""
|
|
176
182
|
self._model.save(save_path)
|
|
177
183
|
|
|
184
|
+
def export_model(self, filename: str, modelname: str = None) -> None:
|
|
185
|
+
"""
|
|
186
|
+
Exports model as a C-header file
|
|
187
|
+
|
|
188
|
+
Args:
|
|
189
|
+
filename (str): Absolute path and name of exported filename.
|
|
190
|
+
"""
|
|
191
|
+
self._model.export(filename, modelname)
|
|
192
|
+
|
|
178
193
|
@classmethod
|
|
179
194
|
def load_model(cls, load_name: str):
|
|
180
195
|
instance = cls.__new__(cls)
|
|
@@ -682,6 +682,28 @@ float GBRL::fit(float *obs, char *categorical_obs, float *targets, int iteration
|
|
|
682
682
|
return full_loss;
|
|
683
683
|
}
|
|
684
684
|
|
|
685
|
+
int GBRL::exportModel(const std::string& filename, const std::string& modelname){
|
|
686
|
+
std::ofstream header_file(filename, std::ios::binary);
|
|
687
|
+
if (!header_file.is_open() || header_file.fail()) {
|
|
688
|
+
std::cerr << "Error opening file: " << filename << std::endl;
|
|
689
|
+
throw std::runtime_error("File opening error");
|
|
690
|
+
return -1;
|
|
691
|
+
}
|
|
692
|
+
if (this->metadata->grow_policy != OBLIVIOUS) {
|
|
693
|
+
std::cerr << "Export is supported only for Oblivious trees." << std::endl;
|
|
694
|
+
header_file.close();
|
|
695
|
+
return -1;
|
|
696
|
+
}
|
|
697
|
+
export_ensemble_data(header_file, modelname, this->edata, this->metadata, this->device, this->opts);
|
|
698
|
+
if (!header_file.good()) {
|
|
699
|
+
std::cerr << "Error occurred at writing time." << std::endl;
|
|
700
|
+
throw std::runtime_error("Writing to file error");
|
|
701
|
+
return -1;
|
|
702
|
+
}
|
|
703
|
+
|
|
704
|
+
header_file.close();
|
|
705
|
+
return 0;
|
|
706
|
+
}
|
|
685
707
|
|
|
686
708
|
int GBRL::saveToFile(const std::string& filename){
|
|
687
709
|
std::ofstream file(filename, std::ios::binary);
|
|
@@ -30,6 +30,7 @@ class GBRL {
|
|
|
30
30
|
void to_device(deviceType device);
|
|
31
31
|
std::string get_device();
|
|
32
32
|
int saveToFile(const std::string& filename);
|
|
33
|
+
int exportModel(const std::string& filename, const std::string& modelname);
|
|
33
34
|
int loadFromFile(const std::string& filename);
|
|
34
35
|
|
|
35
36
|
void step(const float *obs, const char *categorical_obs, float *grads, const int n_samples, const int n_num_features, const int n_cat_features);
|
|
@@ -258,6 +258,10 @@ PYBIND11_MODULE(gbrl_cpp, m) {
|
|
|
258
258
|
py::gil_scoped_release release;
|
|
259
259
|
return self.saveToFile(filename);
|
|
260
260
|
}, "Save the model to a file");
|
|
261
|
+
gbrl.def("export", [](GBRL &self, const std::string& filename, const std::string& modelname) -> int {
|
|
262
|
+
py::gil_scoped_release release;
|
|
263
|
+
return self.exportModel(filename, modelname);
|
|
264
|
+
}, py::arg("filename"), py::arg("modelname") = "", "Export model as a C-header file");
|
|
261
265
|
gbrl.def("get_scheduler_lrs", [](GBRL &self) -> std::tuple<float, float> {
|
|
262
266
|
py::gil_scoped_release release;
|
|
263
267
|
return self.get_scheduler_lrs();
|
|
@@ -6,6 +6,7 @@
|
|
|
6
6
|
#include <stdexcept>
|
|
7
7
|
|
|
8
8
|
#include "types.h"
|
|
9
|
+
#include "optimizer.h"
|
|
9
10
|
#ifdef USE_CUDA
|
|
10
11
|
#include "cuda_types.h"
|
|
11
12
|
#endif
|
|
@@ -282,6 +283,161 @@ void ensemble_data_dealloc(ensembleData *edata){
|
|
|
282
283
|
delete edata;
|
|
283
284
|
}
|
|
284
285
|
|
|
286
|
+
void export_ensemble_data(std::ofstream& header_file, const std::string& model_name, ensembleData *edata, ensembleMetaData *metadata, deviceType device, std::vector<Optimizer*> opts)
|
|
287
|
+
{
|
|
288
|
+
if (!header_file.is_open() || header_file.fail()) {
|
|
289
|
+
std::cerr << "Error file is not open for writing: " << std::endl;
|
|
290
|
+
throw std::runtime_error("Error opening header_file");
|
|
291
|
+
}
|
|
292
|
+
ensembleData *edata_cpu = nullptr;
|
|
293
|
+
#ifdef USE_CUDA
|
|
294
|
+
if (device == gpu){
|
|
295
|
+
edata_cpu = ensemble_data_copy_gpu_cpu(metadata, edata);
|
|
296
|
+
}
|
|
297
|
+
#endif
|
|
298
|
+
if (device == cpu)
|
|
299
|
+
edata_cpu = edata;
|
|
300
|
+
|
|
301
|
+
int binary_splits = 0;
|
|
302
|
+
for (int i = 0; i < metadata->n_trees; ++i){
|
|
303
|
+
binary_splits += edata_cpu->depths[i];
|
|
304
|
+
}
|
|
305
|
+
|
|
306
|
+
for (size_t opt_idx = 0; opt_idx < opts.size(); ++opt_idx){
|
|
307
|
+
optimizerAlgo algo = opts[opt_idx]->getAlgo();
|
|
308
|
+
if (algo != SGD){
|
|
309
|
+
std::cerr << "Error. Can only export SGD optimizers" << std::endl;
|
|
310
|
+
return;
|
|
311
|
+
}
|
|
312
|
+
}
|
|
313
|
+
|
|
314
|
+
header_file << "#ifndef GBRL_MODEL_H\n";
|
|
315
|
+
header_file << "#define GBRL_MODEL_H\n\n";
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
header_file << "/*\n";
|
|
319
|
+
|
|
320
|
+
if (!model_name.empty()) {
|
|
321
|
+
header_file << "###########################\n";
|
|
322
|
+
header_file << "model_name: " << model_name << "\n";
|
|
323
|
+
}
|
|
324
|
+
header_file << "###########################\n";
|
|
325
|
+
header_file << "n_leaves: " << metadata->n_leaves << ", ";
|
|
326
|
+
header_file << "n_trees: " << metadata->n_trees << ", ";
|
|
327
|
+
header_file << "max_trees: " << metadata->max_trees << ", ";
|
|
328
|
+
header_file << "max_leaves: " << metadata->max_leaves << ", ";
|
|
329
|
+
header_file << "max_trees_batch: " << metadata->max_trees_batch << ", ";
|
|
330
|
+
header_file << "max_leaves_batch: " << metadata->max_leaves_batch << ", ";
|
|
331
|
+
header_file << "output_dim: " << metadata->output_dim << ", ";
|
|
332
|
+
header_file << "policy_dim: " << metadata->policy_dim;
|
|
333
|
+
header_file << "\nmax_depth: " << metadata->max_depth << ", ";
|
|
334
|
+
header_file << "min_data_in_leaf: " << metadata->min_data_in_leaf << ", ";
|
|
335
|
+
header_file << "n_bins: " << metadata->n_bins << ", ";
|
|
336
|
+
header_file << "par_th: " << metadata->par_th << ", ";
|
|
337
|
+
header_file << "cv_beta: " << metadata->cv_beta << ", ";
|
|
338
|
+
header_file << "verbose: " << metadata->verbose << ", ";
|
|
339
|
+
header_file << "batch_size: " << metadata->batch_size << ", ";
|
|
340
|
+
header_file << "use_cv: " << metadata->use_cv;
|
|
341
|
+
header_file << "\nsplit_score_func: " << scoreFuncToString(metadata->split_score_func) << ", ";
|
|
342
|
+
header_file << "generator_type: " << generatorTypeToString(metadata->generator_type) << ", ";
|
|
343
|
+
header_file << "grow_policy: " << growPolicyToString(metadata->grow_policy) << ", ";
|
|
344
|
+
header_file << "n_num_features: " << metadata->n_num_features << ", ";
|
|
345
|
+
header_file << "n_cat_features: " << metadata->n_cat_features << ", ";
|
|
346
|
+
header_file << "iteration: " << metadata->iteration;
|
|
347
|
+
header_file << "\n*/\n";
|
|
348
|
+
|
|
349
|
+
header_file << "#define N_TREES " << metadata->n_trees << "\n";
|
|
350
|
+
header_file << "#define N_LEAVES " << metadata->n_leaves << "\n";
|
|
351
|
+
header_file << "#define BINARY_FEATURES " << binary_splits << "\n";
|
|
352
|
+
header_file << "#define N_OUTPUTS " << metadata->output_dim << "\n";
|
|
353
|
+
header_file << "#define N_FEATURES " << metadata->n_num_features << "\n\n";
|
|
354
|
+
|
|
355
|
+
header_file << "static inline void gbrl_predict(float *results, const float *features){\n\n";
|
|
356
|
+
header_file << "\tunsigned int j, tree_idx, depth, current_depth, idx, leaf_ptr, cond_ptr;\n";
|
|
357
|
+
header_file << "\t/* Model data */\n";
|
|
358
|
+
header_file << "\tconst unsigned int depths[N_TREES] = {";
|
|
359
|
+
for (int i = 0; i < metadata->n_trees; ++i){
|
|
360
|
+
header_file << edata_cpu->depths[i];
|
|
361
|
+
if (i < metadata->n_trees - 1)
|
|
362
|
+
header_file << ", ";
|
|
363
|
+
}
|
|
364
|
+
header_file << "};\n";
|
|
365
|
+
header_file << "\tconst float bias[N_OUTPUTS] = {";
|
|
366
|
+
for (int i = 0; i < metadata->output_dim; ++i){
|
|
367
|
+
header_file << edata_cpu->bias[i];
|
|
368
|
+
if (i < metadata->output_dim - 1)
|
|
369
|
+
header_file << ", ";
|
|
370
|
+
}
|
|
371
|
+
header_file << "};\n";
|
|
372
|
+
header_file << "\tconst unsigned int feature_indices[BINARY_FEATURES] = {";
|
|
373
|
+
for (int i = 0; i < binary_splits; ++i){
|
|
374
|
+
header_file << edata_cpu->feature_indices[i];
|
|
375
|
+
if (i < binary_splits - 1)
|
|
376
|
+
header_file << ", ";
|
|
377
|
+
}
|
|
378
|
+
header_file << "};\n";
|
|
379
|
+
header_file << "\tconst float feature_values[BINARY_FEATURES] = {";
|
|
380
|
+
for (int i = 0; i < binary_splits; ++i){
|
|
381
|
+
header_file << edata_cpu->feature_values[i];
|
|
382
|
+
if (i < binary_splits - 1)
|
|
383
|
+
header_file << ", ";
|
|
384
|
+
}
|
|
385
|
+
header_file << "};\n";
|
|
386
|
+
header_file << "\tconst float leaf_values[N_LEAVES*N_OUTPUTS] = {";
|
|
387
|
+
int tree_idx = 0;
|
|
388
|
+
int limit_leaf_idx = edata_cpu->tree_indices[tree_idx];
|
|
389
|
+
float value;
|
|
390
|
+
for (int i = 0; i < metadata->n_leaves; ++i){
|
|
391
|
+
if (i > limit_leaf_idx){
|
|
392
|
+
tree_idx += 1;
|
|
393
|
+
limit_leaf_idx = edata_cpu->tree_indices[tree_idx];
|
|
394
|
+
}
|
|
395
|
+
int value_idx = i*metadata->output_dim;
|
|
396
|
+
for (size_t opt_idx = 0; opt_idx < opts.size(); ++opt_idx){
|
|
397
|
+
for (int j=opts[opt_idx]->start_idx; j < opts[opt_idx]->end_idx; ++j){
|
|
398
|
+
value = -edata_cpu->values[value_idx + j] * opts[opt_idx]->scheduler->get_lr(tree_idx);
|
|
399
|
+
header_file << value;
|
|
400
|
+
if ((i < metadata->n_leaves - 1) || (j < metadata->output_dim - 1 && i == metadata->n_leaves - 1))
|
|
401
|
+
header_file << ", ";
|
|
402
|
+
}
|
|
403
|
+
}
|
|
404
|
+
}
|
|
405
|
+
header_file << "};\n";
|
|
406
|
+
// header_file << "\tconst unsigned int tree_indices[N_TREES] = {";
|
|
407
|
+
// for (int i = 0; i < metadata->n_trees; ++i){
|
|
408
|
+
// header_file << edata_cpu->tree_indices[i];
|
|
409
|
+
// if (i < metadata->n_trees - 1)
|
|
410
|
+
// header_file << ", ";
|
|
411
|
+
// }
|
|
412
|
+
// header_file << "};\n";
|
|
413
|
+
header_file << "\tleaf_ptr = 0;\n";
|
|
414
|
+
header_file << "\tcond_ptr = 0;\n";
|
|
415
|
+
header_file << "\tunsigned char pass;\n";
|
|
416
|
+
header_file << "\tfor (tree_idx = 0; tree_idx < N_TREES; ++tree_idx)\n";
|
|
417
|
+
header_file << "\t{\n";
|
|
418
|
+
header_file << "\t\tcurrent_depth = depths[tree_idx];\n";
|
|
419
|
+
header_file << "\t\tidx = 0;\n";
|
|
420
|
+
header_file << "\t\tfor (depth = 0; depth < current_depth; ++depth){\n";
|
|
421
|
+
header_file << "\t\t\tpass = (unsigned char)(features[feature_indices[cond_ptr + depth]] > feature_values[cond_ptr + depth]);\n";
|
|
422
|
+
header_file << "\t\t\tidx |= (pass << (current_depth - 1 - depth));\n";
|
|
423
|
+
header_file << "\t\t}\n";
|
|
424
|
+
header_file << "\t\tfor (j = 0 ; j < N_OUTPUTS; j++)\n";
|
|
425
|
+
header_file << "\t\t\tresults[j] += leaf_values[(leaf_ptr + idx)*N_OUTPUTS + j];\n";
|
|
426
|
+
header_file << "\t\tleaf_ptr += (1 << current_depth);\n";
|
|
427
|
+
header_file << "\t\tcond_ptr += current_depth;\n";
|
|
428
|
+
header_file << "\t}\n";
|
|
429
|
+
header_file << "\tfor (j = 0 ; j < N_OUTPUTS; j++)\n";
|
|
430
|
+
header_file << "\t\tresults[j] += bias[j];\n";
|
|
431
|
+
header_file << "}\n";
|
|
432
|
+
header_file << "#endif\n";
|
|
433
|
+
|
|
434
|
+
#ifdef USE_CUDA
|
|
435
|
+
if (device == gpu){
|
|
436
|
+
ensemble_data_dealloc(edata_cpu);
|
|
437
|
+
}
|
|
438
|
+
#endif
|
|
439
|
+
}
|
|
440
|
+
|
|
285
441
|
void save_ensemble_data(std::ofstream& file, ensembleData *edata, ensembleMetaData *metadata, deviceType device){
|
|
286
442
|
if (!file.is_open() || file.fail()) {
|
|
287
443
|
std::cerr << "Error file is not open for writing: " << std::endl;
|
|
@@ -338,6 +494,12 @@ void save_ensemble_data(std::ofstream& file, ensembleData *edata, ensembleMetaDa
|
|
|
338
494
|
file.write(reinterpret_cast<char*>(&check), sizeof(NULL_CHECK));
|
|
339
495
|
if (edata_cpu->categorical_values != nullptr)
|
|
340
496
|
file.write(reinterpret_cast<char*>(edata_cpu->categorical_values), metadata->max_depth * sizes * sizeof(char) * MAX_CHAR_SIZE);
|
|
497
|
+
|
|
498
|
+
#ifdef USE_CUDA
|
|
499
|
+
if (device == gpu){
|
|
500
|
+
ensemble_data_dealloc(edata_cpu);
|
|
501
|
+
}
|
|
502
|
+
#endif
|
|
341
503
|
}
|
|
342
504
|
|
|
343
505
|
ensembleData* load_ensemble_data(std::ifstream& file, ensembleMetaData *metadata){
|
|
@@ -10,6 +10,7 @@
|
|
|
10
10
|
#define TREES_BATCH 25000 // 100 K
|
|
11
11
|
#define MAX_CHAR_SIZE 128
|
|
12
12
|
|
|
13
|
+
class Optimizer;
|
|
13
14
|
struct splitCondition {
|
|
14
15
|
int feature_idx;
|
|
15
16
|
float feature_value;
|
|
@@ -165,6 +166,7 @@ ensembleData* ensemble_copy_data_alloc(ensembleMetaData *metadata);
|
|
|
165
166
|
ensembleData* copy_ensemble_data(ensembleData *other_edata, ensembleMetaData *metadata);
|
|
166
167
|
void ensemble_data_dealloc(ensembleData *edata);
|
|
167
168
|
void save_ensemble_data(std::ofstream& file, ensembleData *edata, ensembleMetaData *metadata, deviceType device);
|
|
169
|
+
void export_ensemble_data(std::ofstream& header_file, const std::string& model_name, ensembleData *edata, ensembleMetaData *metadata, deviceType device, std::vector<Optimizer*> opts);
|
|
168
170
|
ensembleData* load_ensemble_data(std::ifstream& file, ensembleMetaData *metadata);
|
|
169
171
|
void allocate_ensemble_memory(ensembleMetaData *metadata, ensembleData *edata);
|
|
170
172
|
#endif
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: gbrl
|
|
3
|
-
Version: 1.0.0.
|
|
3
|
+
Version: 1.0.0.dev7
|
|
4
4
|
Summary: Gradient Boosted Trees for RL
|
|
5
5
|
Author-email: Benjamin Fuhrer <bfuhrer@nvidia.com>, Chen Tessler <ctessler@nvidia.com>, Gal Dalal <galal@nvidia.com>
|
|
6
6
|
Classifier: Development Status :: 4 - Beta
|
|
@@ -33,10 +33,11 @@ GBRL is a Python-based GBT library designed and optimized for reinforcement lear
|
|
|
33
33
|
## Getting started
|
|
34
34
|
|
|
35
35
|
### Dependencies
|
|
36
|
+
#### MAC OS
|
|
37
|
+
```
|
|
36
38
|
llvm
|
|
37
39
|
openmp
|
|
38
|
-
|
|
39
|
-
#### MAC OS
|
|
40
|
+
```
|
|
40
41
|
|
|
41
42
|
Make sure to run:
|
|
42
43
|
```
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|