gbrl 1.0.0.dev5__tar.gz → 1.0.0.dev7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57) hide show
  1. {gbrl-1.0.0.dev5/gbrl.egg-info → gbrl-1.0.0.dev7}/PKG-INFO +4 -3
  2. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/README.md +3 -2
  3. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/gbrl_wrapper.py +14 -0
  4. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/gbt.py +16 -1
  5. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cpp/gbrl.cpp +22 -0
  6. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cpp/gbrl.h +1 -0
  7. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cpp/gbrl_binding.cpp +4 -0
  8. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cpp/types.cpp +162 -0
  9. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cpp/types.h +2 -0
  10. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7/gbrl.egg-info}/PKG-INFO +4 -3
  11. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/pyproject.toml +1 -1
  12. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/LICENSE +0 -0
  13. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/MANIFEST.in +0 -0
  14. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/__init__.py +0 -0
  15. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/ac_gbrl.py +0 -0
  16. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/config.py +0 -0
  17. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cpp/config.h +0 -0
  18. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cpp/fitter.cpp +0 -0
  19. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cpp/fitter.h +0 -0
  20. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cpp/loss.cpp +0 -0
  21. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cpp/loss.h +0 -0
  22. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cpp/main.cpp +0 -0
  23. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cpp/math_ops.cpp +0 -0
  24. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cpp/math_ops.h +0 -0
  25. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cpp/node.cpp +0 -0
  26. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cpp/node.h +0 -0
  27. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cpp/optimizer.cpp +0 -0
  28. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cpp/optimizer.h +0 -0
  29. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cpp/predictor.cpp +0 -0
  30. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cpp/predictor.h +0 -0
  31. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cpp/scheduler.cpp +0 -0
  32. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cpp/scheduler.h +0 -0
  33. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cpp/split_candidate_generator.cpp +0 -0
  34. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cpp/split_candidate_generator.h +0 -0
  35. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cpp/utils.cpp +0 -0
  36. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cpp/utils.h +0 -0
  37. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cuda/cuda_fitter.cu +0 -0
  38. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cuda/cuda_fitter.h +0 -0
  39. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cuda/cuda_loss.cu +0 -0
  40. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cuda/cuda_loss.h +0 -0
  41. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cuda/cuda_predictor.cu +0 -0
  42. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cuda/cuda_predictor.h +0 -0
  43. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cuda/cuda_preprocess.cu +0 -0
  44. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cuda/cuda_preprocess.h +0 -0
  45. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cuda/cuda_types.cu +0 -0
  46. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cuda/cuda_types.h +0 -0
  47. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cuda/cuda_utils.cu +0 -0
  48. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/src/cuda/cuda_utils.h +0 -0
  49. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl/utils.py +0 -0
  50. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl.egg-info/SOURCES.txt +0 -0
  51. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl.egg-info/dependency_links.txt +0 -0
  52. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl.egg-info/requires.txt +0 -0
  53. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/gbrl.egg-info/top_level.txt +0 -0
  54. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/setup.cfg +0 -0
  55. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/setup.py +0 -0
  56. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/tests/test_gbt_multi.py +0 -0
  57. {gbrl-1.0.0.dev5 → gbrl-1.0.0.dev7}/tests/test_gbt_single.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: gbrl
3
- Version: 1.0.0.dev5
3
+ Version: 1.0.0.dev7
4
4
  Summary: Gradient Boosted Trees for RL
5
5
  Author-email: Benjamin Fuhrer <bfuhrer@nvidia.com>, Chen Tessler <ctessler@nvidia.com>, Gal Dalal <galal@nvidia.com>
6
6
  Classifier: Development Status :: 4 - Beta
@@ -33,10 +33,11 @@ GBRL is a Python-based GBT library designed and optimized for reinforcement lear
33
33
  ## Getting started
34
34
 
35
35
  ### Dependencies
36
+ #### MAC OS
37
+ ```
36
38
  llvm
37
39
  openmp
38
-
39
- #### MAC OS
40
+ ```
40
41
 
41
42
  Make sure to run:
42
43
  ```
@@ -11,10 +11,11 @@ GBRL is a Python-based GBT library designed and optimized for reinforcement lear
11
11
  ## Getting started
12
12
 
13
13
  ### Dependencies
14
+ #### MAC OS
15
+ ```
14
16
  llvm
15
17
  openmp
16
-
17
- #### MAC OS
18
+ ```
18
19
 
19
20
  Make sure to run:
20
21
  ```
@@ -126,6 +126,16 @@ class GBTWrapper:
126
126
  status = self.model.save(filename)
127
127
  assert status == 0, "Failed to save model"
128
128
 
129
def export(self, filename: str, modelname: str = None) -> None:
    """Export the model as a C header file.

    Args:
        filename (str): Output path. Trailing dots are stripped and a ``.h``
            extension is appended unless the name already ends with ``.h``.
        modelname (str, optional): Name passed to the backend ``export``;
            ``None`` is sent as an empty string.

    Raises:
        AssertionError: if there is no model, or the backend ``export``
            returns a non-zero status.
    """
    assert self.model is not None, "Can't export non-existent model!"
    filename = filename.rstrip('.')
    # Do not produce "model.h.h" when the caller already supplied the extension.
    if not filename.endswith('.h'):
        filename += '.h'
    if modelname is None:
        modelname = ""
    status = self.model.export(filename, modelname)
    assert status == 0, "Failed to export model"
138
+
129
139
  @classmethod
130
140
  def load(cls, filename: str) -> "GBTWrapper":
131
141
  filename = filename.rstrip('.')
@@ -313,6 +323,10 @@ class SeparateActorCriticWrapper:
313
323
  self.policy_model.save(filename + '_policy')
314
324
  self.value_model.save(filename + '_value')
315
325
 
326
def export(self, filename: str) -> None:
    """Export the actor and the critic as two separate C headers.

    ``filename`` acts as a shared prefix: the policy model is exported as
    ``filename + '_policy'`` and then the value model as
    ``filename + '_value'``; each model's own ``export`` settles the final
    file name.

    Args:
        filename (str): Path prefix shared by both exported headers.
    """
    for attr_name, suffix in (('policy_model', '_policy'), ('value_model', '_value')):
        getattr(self, attr_name).export(filename + suffix)
329
+
316
330
  @classmethod
317
331
  def load(cls, filename: str) -> "SeparateActorCriticWrapper":
318
332
  instance = cls.__new__(cls)
@@ -172,9 +172,24 @@ class GradientBoostingTrees:
172
172
  return self._model.get_num_trees()
173
173
 
174
174
 
175
- def save_model(self, save_path: str=None) -> None:
175
+ def save_model(self, save_path: str) -> None:
176
+ """
177
+ Saves model to file
178
+
179
+ Args:
180
+ filename (str): Absolute path and name of save filename.
181
+ """
176
182
  self._model.save(save_path)
177
183
 
184
def export_model(self, filename: str, modelname: str = None) -> None:
    """
    Exports model as a C-header file

    Args:
        filename (str): Absolute path and name of exported filename.
        modelname (str, optional): Optional model name, forwarded unchanged
            to the underlying model's ``export``. Defaults to None.
    """
    self._model.export(filename, modelname)
192
+
178
193
  @classmethod
179
194
  def load_model(cls, load_name: str):
180
195
  instance = cls.__new__(cls)
@@ -682,6 +682,28 @@ float GBRL::fit(float *obs, char *categorical_obs, float *targets, int iteration
682
682
  return full_loss;
683
683
  }
684
684
 
685
+ int GBRL::exportModel(const std::string& filename, const std::string& modelname){
686
+ std::ofstream header_file(filename, std::ios::binary);
687
+ if (!header_file.is_open() || header_file.fail()) {
688
+ std::cerr << "Error opening file: " << filename << std::endl;
689
+ throw std::runtime_error("File opening error");
690
+ return -1;
691
+ }
692
+ if (this->metadata->grow_policy != OBLIVIOUS) {
693
+ std::cerr << "Export is supported only for Oblivious trees." << std::endl;
694
+ header_file.close();
695
+ return -1;
696
+ }
697
+ export_ensemble_data(header_file, modelname, this->edata, this->metadata, this->device, this->opts);
698
+ if (!header_file.good()) {
699
+ std::cerr << "Error occurred at writing time." << std::endl;
700
+ throw std::runtime_error("Writing to file error");
701
+ return -1;
702
+ }
703
+
704
+ header_file.close();
705
+ return 0;
706
+ }
685
707
 
686
708
  int GBRL::saveToFile(const std::string& filename){
687
709
  std::ofstream file(filename, std::ios::binary);
@@ -30,6 +30,7 @@ class GBRL {
30
30
  void to_device(deviceType device);
31
31
  std::string get_device();
32
32
  int saveToFile(const std::string& filename);
33
+ int exportModel(const std::string& filename, const std::string& modelname);
33
34
  int loadFromFile(const std::string& filename);
34
35
 
35
36
  void step(const float *obs, const char *categorical_obs, float *grads, const int n_samples, const int n_num_features, const int n_cat_features);
@@ -258,6 +258,10 @@ PYBIND11_MODULE(gbrl_cpp, m) {
258
258
  py::gil_scoped_release release;
259
259
  return self.saveToFile(filename);
260
260
  }, "Save the model to a file");
261
+ gbrl.def("export", [](GBRL &self, const std::string& filename, const std::string& modelname) -> int {
262
+ py::gil_scoped_release release;
263
+ return self.exportModel(filename, modelname);
264
+ }, py::arg("filename"), py::arg("modelname") = "", "Export model as a C-header file");
261
265
  gbrl.def("get_scheduler_lrs", [](GBRL &self) -> std::tuple<float, float> {
262
266
  py::gil_scoped_release release;
263
267
  return self.get_scheduler_lrs();
@@ -6,6 +6,7 @@
6
6
  #include <stdexcept>
7
7
 
8
8
  #include "types.h"
9
+ #include "optimizer.h"
9
10
  #ifdef USE_CUDA
10
11
  #include "cuda_types.h"
11
12
  #endif
@@ -282,6 +283,161 @@ void ensemble_data_dealloc(ensembleData *edata){
282
283
  delete edata;
283
284
  }
284
285
 
286
+ void export_ensemble_data(std::ofstream& header_file, const std::string& model_name, ensembleData *edata, ensembleMetaData *metadata, deviceType device, std::vector<Optimizer*> opts)
287
+ {
288
+ if (!header_file.is_open() || header_file.fail()) {
289
+ std::cerr << "Error file is not open for writing: " << std::endl;
290
+ throw std::runtime_error("Error opening header_file");
291
+ }
292
+ ensembleData *edata_cpu = nullptr;
293
+ #ifdef USE_CUDA
294
+ if (device == gpu){
295
+ edata_cpu = ensemble_data_copy_gpu_cpu(metadata, edata);
296
+ }
297
+ #endif
298
+ if (device == cpu)
299
+ edata_cpu = edata;
300
+
301
+ int binary_splits = 0;
302
+ for (int i = 0; i < metadata->n_trees; ++i){
303
+ binary_splits += edata_cpu->depths[i];
304
+ }
305
+
306
+ for (size_t opt_idx = 0; opt_idx < opts.size(); ++opt_idx){
307
+ optimizerAlgo algo = opts[opt_idx]->getAlgo();
308
+ if (algo != SGD){
309
+ std::cerr << "Error. Can only export SGD optimizers" << std::endl;
310
+ return;
311
+ }
312
+ }
313
+
314
+ header_file << "#ifndef GBRL_MODEL_H\n";
315
+ header_file << "#define GBRL_MODEL_H\n\n";
316
+
317
+
318
+ header_file << "/*\n";
319
+
320
+ if (!model_name.empty()) {
321
+ header_file << "###########################\n";
322
+ header_file << "model_name: " << model_name << "\n";
323
+ }
324
+ header_file << "###########################\n";
325
+ header_file << "n_leaves: " << metadata->n_leaves << ", ";
326
+ header_file << "n_trees: " << metadata->n_trees << ", ";
327
+ header_file << "max_trees: " << metadata->max_trees << ", ";
328
+ header_file << "max_leaves: " << metadata->max_leaves << ", ";
329
+ header_file << "max_trees_batch: " << metadata->max_trees_batch << ", ";
330
+ header_file << "max_leaves_batch: " << metadata->max_leaves_batch << ", ";
331
+ header_file << "output_dim: " << metadata->output_dim << ", ";
332
+ header_file << "policy_dim: " << metadata->policy_dim;
333
+ header_file << "\nmax_depth: " << metadata->max_depth << ", ";
334
+ header_file << "min_data_in_leaf: " << metadata->min_data_in_leaf << ", ";
335
+ header_file << "n_bins: " << metadata->n_bins << ", ";
336
+ header_file << "par_th: " << metadata->par_th << ", ";
337
+ header_file << "cv_beta: " << metadata->cv_beta << ", ";
338
+ header_file << "verbose: " << metadata->verbose << ", ";
339
+ header_file << "batch_size: " << metadata->batch_size << ", ";
340
+ header_file << "use_cv: " << metadata->use_cv;
341
+ header_file << "\nsplit_score_func: " << scoreFuncToString(metadata->split_score_func) << ", ";
342
+ header_file << "generator_type: " << generatorTypeToString(metadata->generator_type) << ", ";
343
+ header_file << "grow_policy: " << growPolicyToString(metadata->grow_policy) << ", ";
344
+ header_file << "n_num_features: " << metadata->n_num_features << ", ";
345
+ header_file << "n_cat_features: " << metadata->n_cat_features << ", ";
346
+ header_file << "iteration: " << metadata->iteration;
347
+ header_file << "\n*/\n";
348
+
349
+ header_file << "#define N_TREES " << metadata->n_trees << "\n";
350
+ header_file << "#define N_LEAVES " << metadata->n_leaves << "\n";
351
+ header_file << "#define BINARY_FEATURES " << binary_splits << "\n";
352
+ header_file << "#define N_OUTPUTS " << metadata->output_dim << "\n";
353
+ header_file << "#define N_FEATURES " << metadata->n_num_features << "\n\n";
354
+
355
+ header_file << "static inline void gbrl_predict(float *results, const float *features){\n\n";
356
+ header_file << "\tunsigned int j, tree_idx, depth, current_depth, idx, leaf_ptr, cond_ptr;\n";
357
+ header_file << "\t/* Model data */\n";
358
+ header_file << "\tconst unsigned int depths[N_TREES] = {";
359
+ for (int i = 0; i < metadata->n_trees; ++i){
360
+ header_file << edata_cpu->depths[i];
361
+ if (i < metadata->n_trees - 1)
362
+ header_file << ", ";
363
+ }
364
+ header_file << "};\n";
365
+ header_file << "\tconst float bias[N_OUTPUTS] = {";
366
+ for (int i = 0; i < metadata->output_dim; ++i){
367
+ header_file << edata_cpu->bias[i];
368
+ if (i < metadata->output_dim - 1)
369
+ header_file << ", ";
370
+ }
371
+ header_file << "};\n";
372
+ header_file << "\tconst unsigned int feature_indices[BINARY_FEATURES] = {";
373
+ for (int i = 0; i < binary_splits; ++i){
374
+ header_file << edata_cpu->feature_indices[i];
375
+ if (i < binary_splits - 1)
376
+ header_file << ", ";
377
+ }
378
+ header_file << "};\n";
379
+ header_file << "\tconst float feature_values[BINARY_FEATURES] = {";
380
+ for (int i = 0; i < binary_splits; ++i){
381
+ header_file << edata_cpu->feature_values[i];
382
+ if (i < binary_splits - 1)
383
+ header_file << ", ";
384
+ }
385
+ header_file << "};\n";
386
+ header_file << "\tconst float leaf_values[N_LEAVES*N_OUTPUTS] = {";
387
+ int tree_idx = 0;
388
+ int limit_leaf_idx = edata_cpu->tree_indices[tree_idx];
389
+ float value;
390
+ for (int i = 0; i < metadata->n_leaves; ++i){
391
+ if (i > limit_leaf_idx){
392
+ tree_idx += 1;
393
+ limit_leaf_idx = edata_cpu->tree_indices[tree_idx];
394
+ }
395
+ int value_idx = i*metadata->output_dim;
396
+ for (size_t opt_idx = 0; opt_idx < opts.size(); ++opt_idx){
397
+ for (int j=opts[opt_idx]->start_idx; j < opts[opt_idx]->end_idx; ++j){
398
+ value = -edata_cpu->values[value_idx + j] * opts[opt_idx]->scheduler->get_lr(tree_idx);
399
+ header_file << value;
400
+ if ((i < metadata->n_leaves - 1) || (j < metadata->output_dim - 1 && i == metadata->n_leaves - 1))
401
+ header_file << ", ";
402
+ }
403
+ }
404
+ }
405
+ header_file << "};\n";
406
+ // header_file << "\tconst unsigned int tree_indices[N_TREES] = {";
407
+ // for (int i = 0; i < metadata->n_trees; ++i){
408
+ // header_file << edata_cpu->tree_indices[i];
409
+ // if (i < metadata->n_trees - 1)
410
+ // header_file << ", ";
411
+ // }
412
+ // header_file << "};\n";
413
+ header_file << "\tleaf_ptr = 0;\n";
414
+ header_file << "\tcond_ptr = 0;\n";
415
+ header_file << "\tunsigned char pass;\n";
416
+ header_file << "\tfor (tree_idx = 0; tree_idx < N_TREES; ++tree_idx)\n";
417
+ header_file << "\t{\n";
418
+ header_file << "\t\tcurrent_depth = depths[tree_idx];\n";
419
+ header_file << "\t\tidx = 0;\n";
420
+ header_file << "\t\tfor (depth = 0; depth < current_depth; ++depth){\n";
421
+ header_file << "\t\t\tpass = (unsigned char)(features[feature_indices[cond_ptr + depth]] > feature_values[cond_ptr + depth]);\n";
422
+ header_file << "\t\t\tidx |= (pass << (current_depth - 1 - depth));\n";
423
+ header_file << "\t\t}\n";
424
+ header_file << "\t\tfor (j = 0 ; j < N_OUTPUTS; j++)\n";
425
+ header_file << "\t\t\tresults[j] += leaf_values[(leaf_ptr + idx)*N_OUTPUTS + j];\n";
426
+ header_file << "\t\tleaf_ptr += (1 << current_depth);\n";
427
+ header_file << "\t\tcond_ptr += current_depth;\n";
428
+ header_file << "\t}\n";
429
+ header_file << "\tfor (j = 0 ; j < N_OUTPUTS; j++)\n";
430
+ header_file << "\t\tresults[j] += bias[j];\n";
431
+ header_file << "}\n";
432
+ header_file << "#endif\n";
433
+
434
+ #ifdef USE_CUDA
435
+ if (device == gpu){
436
+ ensemble_data_dealloc(edata_cpu);
437
+ }
438
+ #endif
439
+ }
440
+
285
441
  void save_ensemble_data(std::ofstream& file, ensembleData *edata, ensembleMetaData *metadata, deviceType device){
286
442
  if (!file.is_open() || file.fail()) {
287
443
  std::cerr << "Error file is not open for writing: " << std::endl;
@@ -338,6 +494,12 @@ void save_ensemble_data(std::ofstream& file, ensembleData *edata, ensembleMetaDa
338
494
  file.write(reinterpret_cast<char*>(&check), sizeof(NULL_CHECK));
339
495
  if (edata_cpu->categorical_values != nullptr)
340
496
  file.write(reinterpret_cast<char*>(edata_cpu->categorical_values), metadata->max_depth * sizes * sizeof(char) * MAX_CHAR_SIZE);
497
+
498
+ #ifdef USE_CUDA
499
+ if (device == gpu){
500
+ ensemble_data_dealloc(edata_cpu);
501
+ }
502
+ #endif
341
503
  }
342
504
 
343
505
  ensembleData* load_ensemble_data(std::ifstream& file, ensembleMetaData *metadata){
@@ -10,6 +10,7 @@
10
10
  #define TREES_BATCH 25000 // 100 K
11
11
  #define MAX_CHAR_SIZE 128
12
12
 
13
+ class Optimizer;
13
14
  struct splitCondition {
14
15
  int feature_idx;
15
16
  float feature_value;
@@ -165,6 +166,7 @@ ensembleData* ensemble_copy_data_alloc(ensembleMetaData *metadata);
165
166
  ensembleData* copy_ensemble_data(ensembleData *other_edata, ensembleMetaData *metadata);
166
167
  void ensemble_data_dealloc(ensembleData *edata);
167
168
  void save_ensemble_data(std::ofstream& file, ensembleData *edata, ensembleMetaData *metadata, deviceType device);
169
+ void export_ensemble_data(std::ofstream& header_file, const std::string& model_name, ensembleData *edata, ensembleMetaData *metadata, deviceType device, std::vector<Optimizer*> opts);
168
170
  ensembleData* load_ensemble_data(std::ifstream& file, ensembleMetaData *metadata);
169
171
  void allocate_ensemble_memory(ensembleMetaData *metadata, ensembleData *edata);
170
172
  #endif
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: gbrl
3
- Version: 1.0.0.dev5
3
+ Version: 1.0.0.dev7
4
4
  Summary: Gradient Boosted Trees for RL
5
5
  Author-email: Benjamin Fuhrer <bfuhrer@nvidia.com>, Chen Tessler <ctessler@nvidia.com>, Gal Dalal <galal@nvidia.com>
6
6
  Classifier: Development Status :: 4 - Beta
@@ -33,10 +33,11 @@ GBRL is a Python-based GBT library designed and optimized for reinforcement lear
33
33
  ## Getting started
34
34
 
35
35
  ### Dependencies
36
+ #### MAC OS
37
+ ```
36
38
  llvm
37
39
  openmp
38
-
39
- #### MAC OS
40
+ ```
40
41
 
41
42
  Make sure to run:
42
43
  ```
@@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta"
8
8
 
9
9
  [project]
10
10
  name = "gbrl"
11
- version = "1.0.0.dev5"
11
+ version = "1.0.0.dev7"
12
12
  description = "Gradient Boosted Trees for RL"
13
13
  readme = { file = "README.md", content-type = "text/markdown" }
14
14
  authors = [
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes