wandb 0.1.5 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 4a00bf9591759e1339bca2ff33aca8dc245f71e8ef8f9e3aa4cd70a59f598a10
4
- data.tar.gz: 880bed08a72bb3f03d48d093ab68a94ab8b706b759ec10945008979fd0c044c8
3
+ metadata.gz: 6f0f0909ffdfe53479bf0c6a09f09f96c7e0aec1dc846f16d9b42c1b1e7f4fae
4
+ data.tar.gz: f1ffc1d2cdabc05964ef0187f5ab168a71ef38077591e3a07bf960da3b5ae8b5
5
5
  SHA512:
6
- metadata.gz: b6148f71160761a123696f0beac8b2ebe68064a4f64bd55ce1e9ec417c44b5c032a9f60b77394a3ebd0f0535e60c2ce50e53b5606f8acb958bfd35148c0f4a53
7
- data.tar.gz: 14a72f4b7b0bc05234d9e0e276c724b1077ac61c07b755f90ce365d4d596c9183928b5df6d9396391a9339d6441b92e1e4c0733f67ee19634ba7fec47acb6818
6
+ metadata.gz: 93b0f7df5f4bb660a9d566b3492a75690336eeddbfbe0aad15fc100c292499f5fb93a543f1f9f72b6120b2682d8ee25dd3b5ead6906500efd1f5ccc29f6a7219
7
+ data.tar.gz: 95e7472bfa9e0322bb353d6ee77bd76a7b6f1c09edd6ef23f10f294ef982b35bae279b5ccbb4beee44991decf774b3bac45ab708ebbab4dbb0dda47917715545
data/lib/wandb/version.rb CHANGED
@@ -1,5 +1,5 @@
1
1
  # frozen_string_literal: true
2
2
 
3
3
  module Wandb
4
- VERSION = "0.1.5"
4
+ VERSION = "0.1.7"
5
5
  end
@@ -1,6 +1,7 @@
1
1
  require "xgb"
2
2
  require "tempfile"
3
3
  require "fileutils"
4
+ require "xgboost/training_callback"
4
5
 
5
6
  module Wandb
6
7
  class XGBoostCallback < XGBoost::TrainingCallback
@@ -19,7 +20,7 @@ module Wandb
19
20
  end
20
21
  end
21
22
 
22
- attr_accessor :project_name, :api_key
23
+ attr_accessor :project_name, :api_key, :custom_loggers
23
24
 
24
25
  def initialize(options = {})
25
26
  options = Opts.new(options)
@@ -29,22 +30,21 @@ module Wandb
29
30
  @define_metric = options.default(:define_metric, true)
30
31
  @api_key = options.default(:api_key, ENV["WANDB_API_KEY"])
31
32
  @project_name = options.default(:project_name, nil)
32
-
33
- raise "WANDB_API_KEY required" unless api_key
34
- raise "project_name required" unless project_name
35
-
36
- Wandb.login(api_key: api_key)
37
- Wandb.init(project: project_name)
38
-
39
- return if Wandb.current_run
40
-
41
- raise "You must call wandb.init() before WandbCallback()"
33
+ @custom_loggers = options.default(:custom_loggers, [])
42
34
  end
43
35
 
44
36
  def before_training(model)
45
- # Update Wandb config with model configuration
46
- Wandb.current_run.config = model.params
47
- Wandb.log(model.params)
37
+ Wandb.login(api_key: api_key)
38
+ Wandb.init(project: project_name)
39
+ config = JSON.parse(model.save_config)
40
+ log_conf = {
41
+ learning_rate: config.dig("learner", "gradient_booster", "tree_train_param", "learning_rate").to_f,
42
+ max_depth: config.dig("learner", "gradient_booster", "tree_train_param", "max_depth").to_f,
43
+ n_estimators: model.num_boosted_rounds
44
+ }
45
+ Wandb.current_run.config = log_conf
46
+
47
+ Wandb.log(log_conf)
48
48
  model
49
49
  end
50
50
 
@@ -57,7 +57,7 @@ module Wandb
57
57
 
58
58
  # Log best score and best iteration
59
59
  unless model.best_score
60
- Wandb.finish
60
+ finish
61
61
  return model
62
62
  end
63
63
 
@@ -65,18 +65,21 @@ module Wandb
65
65
  "best_score" => model.best_score.to_f,
66
66
  "best_iteration" => model.best_iteration.to_i
67
67
  )
68
-
69
- Wandb.finish
70
- FileUtils.rm_rf(Rails.root.join("wandb"))
68
+ finish
71
69
 
72
70
  model
73
71
  end
74
72
 
73
+ def finish
74
+ Wandb.finish
75
+ FileUtils.rm_rf(File.join(Dir.pwd, "wandb"))
76
+ end
77
+
75
78
  def before_iteration(_model, _epoch, _history)
76
79
  false
77
80
  end
78
81
 
79
- def after_iteration(_model, epoch, history)
82
+ def after_iteration(model, epoch, history)
80
83
  history.each do |split, metric_scores|
81
84
  metric = metric_scores.keys.first
82
85
  values = metric_scores.values.last
@@ -86,6 +89,9 @@ module Wandb
86
89
  full_metric_name = "#{split}-#{metric}"
87
90
  Wandb.log({ full_metric_name => epoch_value })
88
91
  end
92
+ @custom_loggers.each do |logger|
93
+ logger.call(model, epoch, history)
94
+ end
89
95
  Wandb.log("epoch" => epoch)
90
96
  false
91
97
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: wandb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.5
4
+ version: 0.1.7
5
5
  platform: ruby
6
6
  authors:
7
7
  - Brett Shollenberger
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2024-10-16 00:00:00.000000000 Z
11
+ date: 2024-10-18 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: pycall