wandb 0.1.2 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: fc087b1331eb3548a0c0826543a956c44ac2f301cf67de163a8b98cfbd425dd1
4
- data.tar.gz: 7320addea0abaa3851ae9a1990c078f78fdb3f55d284ab7b7617cbb4e2203d92
3
+ metadata.gz: 310ede46ccba30aa00bfbb81fbb8ffdefd818528c7ae09523e7dca9629258892
4
+ data.tar.gz: 07643a463e807e99db1a4e874bd85b749f6d65b644f15e0423505ead63ec7749
5
5
  SHA512:
6
- metadata.gz: 8225c6929a867bc6d7e3da95698fa4d71f7f87bf2e8b1004601aa9cb7c8214c327570531a292b5f5df5a85f608fb498944ef22093803f7e527c49f87af640f74
7
- data.tar.gz: 22e9685e8b57ea486f19fc91d3253ebed100c708f40aa739155b7eb5ec6c664a3ef6108ead7c933cd5c6e0a3324cb91b1020f3765d7b44143d87ca5c80d72f4e
6
+ metadata.gz: 00e7a6fe5e888d931c1abbc6047207859a46b024c83a05a6e87feeea8d20576f2f3da1d6a5c851c9dd945e7bd564ecfc56b5c4707ed3b54284aea97ecb4dac46
7
+ data.tar.gz: d4b2c592d6a4a5e079024d48cb7a97834a883f1e5e12dbe1c856255d4beed84c90f54f629d818cad14513f4778b08657e6781c958ff977fe607d1a74c0a2deec
data/lib/wandb/version.rb CHANGED
@@ -1,5 +1,5 @@
1
1
  # frozen_string_literal: true
2
2
 
3
3
  module Wandb
4
- VERSION = "0.1.2"
4
+ VERSION = "0.1.6"
5
5
  end
@@ -1,26 +1,53 @@
1
+ require "xgb"
2
+ require "tempfile"
3
+ require "fileutils"
4
+
1
5
  module Wandb
2
- class XGBoostCallback
6
+ class XGBoostCallback < XGBoost::TrainingCallback
3
7
  MINIMIZE_METRICS = %w[rmse logloss error] # Add other metrics as needed
4
8
  MAXIMIZE_METRICS = %w[auc accuracy] # Add other metrics as needed
5
9
 
6
- def initialize(log_model: false, log_feature_importance: true, importance_type: "gain", define_metric: true)
7
- @log_model = log_model
8
- @log_feature_importance = log_feature_importance
9
- @importance_type = importance_type
10
- @define_metric = define_metric
10
+ class Opts
11
+ attr_accessor :options
11
12
 
12
- return if Wandb.current_run
13
+ def initialize(options = {})
14
+ @options = options
15
+ end
13
16
 
14
- raise "You must call wandb.init() before WandbCallback()"
17
+ def default(key, default)
18
+ options.key?(key) ? options[key] : default
19
+ end
15
20
  end
16
21
 
17
- def before_training(model:)
18
- # Update Wandb config with model configuration
19
- Wandb.current_run.config = model.params
20
- Wandb.log(model.params)
22
+ attr_accessor :project_name, :api_key, :custom_loggers
23
+
24
+ def initialize(options = {})
25
+ options = Opts.new(options)
26
+ @log_model = options.default(:log_model, false)
27
+ @log_feature_importance = options.default(:log_feature_importance, true)
28
+ @importance_type = options.default(:importance_type, "gain")
29
+ @define_metric = options.default(:define_metric, true)
30
+ @api_key = options.default(:api_key, ENV["WANDB_API_KEY"])
31
+ @project_name = options.default(:project_name, nil)
32
+ @custom_loggers = options.default(:custom_loggers, [])
21
33
  end
22
34
 
23
- def after_training(model:)
35
+ def before_training(model)
36
+ Wandb.login(api_key: api_key)
37
+ Wandb.init(project: project_name)
38
+ config = JSON.parse(model.save_config)
39
+ log_conf = {
40
+ learning_rate: config.dig("learner", "gradient_booster", "tree_train_param", "learning_rate").to_f,
41
+ max_depth: config.dig("learner", "gradient_booster", "tree_train_param", "max_depth").to_f,
42
+ n_estimators: model.num_boosted_rounds
43
+ }
44
+ Wandb.current_run.config = log_conf
45
+
46
+ Wandb.log(log_conf)
47
+ model
48
+ end
49
+
50
+ def after_training(model)
24
51
  # Log the model as an artifact
25
52
  log_model_as_artifact(model) if @log_model
26
53
 
@@ -28,54 +55,67 @@ module Wandb
28
55
  log_feature_importance(model) if @log_feature_importance
29
56
 
30
57
  # Log best score and best iteration
31
- return unless model.best_score
58
+ unless model.best_score
59
+ finish
60
+ return model
61
+ end
32
62
 
33
63
  Wandb.log(
34
64
  "best_score" => model.best_score.to_f,
35
65
  "best_iteration" => model.best_iteration.to_i
36
66
  )
67
+ finish
68
+
69
+ model
37
70
  end
38
71
 
39
- def before_iteration(model:, epoch:, evals:)
40
- # noop
72
+ def finish
73
+ Wandb.finish
74
+ FileUtils.rm_rf(File.join(Dir.pwd, "wandb"))
41
75
  end
42
76
 
43
- def after_iteration(model:, epoch:, evals:, res:)
44
- res.each do |metric_name, value|
45
- data, metric = metric_name.split("-", 2)
46
- full_metric_name = "#{data}-#{metric}"
47
-
48
- if @define_metric
49
- define_metric(data, metric)
50
- Wandb.log({ full_metric_name => value })
51
- else
52
- Wandb.log({ full_metric_name => value })
53
- end
54
- end
77
+ def before_iteration(_model, _epoch, _history)
78
+ false
79
+ end
55
80
 
56
- Wandb.log({ "epoch" => epoch })
57
- @define_metric = false
81
+ def after_iteration(model, epoch, history)
82
+ history.each do |split, metric_scores|
83
+ metric = metric_scores.keys.first
84
+ values = metric_scores.values.last
85
+ epoch_value = values[epoch]
86
+
87
+ define_metric(split, metric) if @define_metric && epoch == 0
88
+ full_metric_name = "#{split}-#{metric}"
89
+ Wandb.log({ full_metric_name => epoch_value })
90
+ end
91
+ @custom_loggers.each do |logger|
92
+ logger.call(model, epoch, history)
93
+ end
94
+ Wandb.log("epoch" => epoch)
95
+ false
58
96
  end
59
97
 
60
98
  private
61
99
 
62
100
  def log_model_as_artifact(model)
63
- model_name = "#{Wandb.current_run.id}_model.json"
64
- model_path = File.join(Wandb.current_run.dir, model_name)
65
- model.save_model(model_path)
66
-
67
- model_artifact = Wandb.Artifact(name: model_name, type: "model")
68
- model_artifact.add_file(model_path)
69
- Wandb.current_run.log_artifact(model_artifact)
101
+ Dir.mktmpdir("wandb_xgboost_model") do |tmp_dir|
102
+ model_name = "model.json"
103
+ model_path = File.join(tmp_dir, model_name)
104
+ model.save_model(model_path)
105
+
106
+ model_artifact = Wandb.artifact(name: model_name, type: "model")
107
+ model_artifact.add_file(model_path)
108
+ Wandb.current_run.log_artifact(model_artifact)
109
+ end
70
110
  end
71
111
 
72
112
  def log_feature_importance(model)
73
113
  fi = model.score(importance_type: @importance_type)
74
114
  fi_data = fi.map { |k, v| [k, v] }
75
115
 
76
- table = Wandb.Table(data: fi_data, columns: %w[Feature Importance])
77
- bar_plot = Wandb.plot.bar(table, "Feature", "Importance", title: "Feature Importance")
78
- Wandb.log({ "Feature Importance" => bar_plot })
116
+ table = Wandb::Table.new(data: fi_data, columns: %w[Feature Importance])
117
+ bar_plot = Wandb::Plot.bar(table.table, "Feature", "Importance", title: "Feature Importance")
118
+ Wandb.log({ "Feature Importance" => bar_plot.__pyptr__ })
79
119
  end
80
120
 
81
121
  def define_metric(data, metric_name)
data/lib/wandb.rb CHANGED
@@ -16,10 +16,6 @@ module Wandb
16
16
  @wandb ||= PyCall.import_module("wandb")
17
17
  end
18
18
 
19
- def Table(*args, **kwargs)
20
- __pyptr__.Table.new(*args, **kwargs)
21
- end
22
-
23
19
  def plot(*args, **kwargs)
24
20
  __pyptr__.plot(*args, **kwargs)
25
21
  end
@@ -30,11 +26,12 @@ module Wandb
30
26
  end
31
27
 
32
28
  # Expose wandb.Artifact
33
- def Artifact(*args, **kwargs)
34
- __pyptr__.Artifact.new(*args, **kwargs)
29
+ def artifact(*args, **kwargs)
30
+ py_artifact = __pyptr__.Artifact.new(*args, **kwargs)
31
+ Artifact.new(py_artifact)
35
32
  end
36
33
 
37
- def Error
34
+ def error
38
35
  __pyptr__.Error
39
36
  end
40
37
 
@@ -72,6 +69,10 @@ module Wandb
72
69
  def api
73
70
  @api ||= Api.new(__pyptr__.Api.new)
74
71
  end
72
+
73
+ def plot
74
+ Plot
75
+ end
75
76
  end
76
77
 
77
78
  # Run class
@@ -80,6 +81,10 @@ module Wandb
80
81
  @run = run
81
82
  end
82
83
 
84
+ def run_id
85
+ @run.run_id
86
+ end
87
+
83
88
  def log(metrics = {})
84
89
  metrics.symbolize_keys!
85
90
  @run.log(metrics, {})
@@ -104,6 +109,53 @@ module Wandb
104
109
  def config=(new_config)
105
110
  @run.config.update(PyCall::Dict.new(new_config))
106
111
  end
112
+
113
+ def log_artifact(artifact)
114
+ @run.log_artifact(artifact.__pyptr__)
115
+ end
116
+ end
117
+
118
+ # Artifact class
119
+ class Artifact
120
+ def initialize(artifact)
121
+ @artifact = artifact
122
+ end
123
+
124
+ def __pyptr__
125
+ @artifact
126
+ end
127
+
128
+ def name
129
+ @artifact.name
130
+ end
131
+
132
+ def type
133
+ @artifact.type
134
+ end
135
+
136
+ def add_file(local_path, name = nil)
137
+ @artifact.add_file(local_path, name)
138
+ end
139
+
140
+ def add_dir(local_dir, name = nil)
141
+ @artifact.add_dir(local_dir, name)
142
+ end
143
+
144
+ def get_path(name)
145
+ @artifact.get_path(name)
146
+ end
147
+
148
+ def metadata
149
+ @artifact.metadata
150
+ end
151
+
152
+ def metadata=(new_metadata)
153
+ @artifact.metadata = new_metadata
154
+ end
155
+
156
+ def save
157
+ @artifact.save
158
+ end
107
159
  end
108
160
 
109
161
  # Api class
@@ -137,6 +189,78 @@ module Wandb
137
189
  @project.description
138
190
  end
139
191
  end
192
+
193
+ # Table class
194
+ class Table
195
+ attr_accessor :table, :data, :columns
196
+
197
+ def initialize(data: {}, columns: [])
198
+ @table = Wandb.__pyptr__.Table.new(data: data, columns: columns)
199
+ @data = data
200
+ @columns = columns
201
+ end
202
+
203
+ def __pyptr__
204
+ @table
205
+ end
206
+
207
+ def add_data(*args)
208
+ @table.add_data(*args)
209
+ end
210
+
211
+ def add_column(name, data)
212
+ @table.add_column(name, data)
213
+ end
214
+
215
+ def get_column(name)
216
+ @table.get_column(name)
217
+ end
218
+
219
+ def columns
220
+ @table.columns
221
+ end
222
+
223
+ def data
224
+ @table.data
225
+ end
226
+
227
+ def to_pandas
228
+ @table.get_dataframe
229
+ end
230
+ end
231
+
232
+ # Plot class
233
+ class Plot
234
+ class << self
235
+ def bar(table, x_key, y_key, title: nil)
236
+ py_plot = Wandb.__pyptr__.plot.bar(table.__pyptr__, x_key, y_key, title: title)
237
+ new(py_plot)
238
+ end
239
+
240
+ def line(table, x_key, y_key, title: nil)
241
+ py_plot = Wandb.__pyptr__.plot.line(table.__pyptr__, x_key, y_key, title: title)
242
+ new(py_plot)
243
+ end
244
+
245
+ def scatter(table, x_key, y_key, title: nil)
246
+ py_plot = Wandb.__pyptr__.plot.scatter(table.__pyptr__, x_key, y_key, title: title)
247
+ new(py_plot)
248
+ end
249
+
250
+ def histogram(table, value_key, title: nil)
251
+ py_plot = Wandb.__pyptr__.plot.histogram(table.__pyptr__, value_key, title: title)
252
+ new(py_plot)
253
+ end
254
+ end
255
+
256
+ def initialize(plot)
257
+ @plot = plot
258
+ end
259
+
260
+ def __pyptr__
261
+ @plot
262
+ end
263
+ end
140
264
  end
141
265
 
142
266
  require_relative "wandb/xgboost_callback"
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: wandb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.2
4
+ version: 0.1.6
5
5
  platform: ruby
6
6
  authors:
7
7
  - Brett Shollenberger
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2024-10-11 00:00:00.000000000 Z
11
+ date: 2024-10-18 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: pycall