wandb 0.1.2 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: fc087b1331eb3548a0c0826543a956c44ac2f301cf67de163a8b98cfbd425dd1
4
- data.tar.gz: 7320addea0abaa3851ae9a1990c078f78fdb3f55d284ab7b7617cbb4e2203d92
3
+ metadata.gz: 4a00bf9591759e1339bca2ff33aca8dc245f71e8ef8f9e3aa4cd70a59f598a10
4
+ data.tar.gz: 880bed08a72bb3f03d48d093ab68a94ab8b706b759ec10945008979fd0c044c8
5
5
  SHA512:
6
- metadata.gz: 8225c6929a867bc6d7e3da95698fa4d71f7f87bf2e8b1004601aa9cb7c8214c327570531a292b5f5df5a85f608fb498944ef22093803f7e527c49f87af640f74
7
- data.tar.gz: 22e9685e8b57ea486f19fc91d3253ebed100c708f40aa739155b7eb5ec6c664a3ef6108ead7c933cd5c6e0a3324cb91b1020f3765d7b44143d87ca5c80d72f4e
6
+ metadata.gz: b6148f71160761a123696f0beac8b2ebe68064a4f64bd55ce1e9ec417c44b5c032a9f60b77394a3ebd0f0535e60c2ce50e53b5606f8acb958bfd35148c0f4a53
7
+ data.tar.gz: 14a72f4b7b0bc05234d9e0e276c724b1077ac61c07b755f90ce365d4d596c9183928b5df6d9396391a9339d6441b92e1e4c0733f67ee19634ba7fec47acb6818
data/lib/wandb/version.rb CHANGED
@@ -1,5 +1,5 @@
1
1
  # frozen_string_literal: true
2
2
 
3
3
  module Wandb
4
- VERSION = "0.1.2"
4
+ VERSION = "0.1.5"
5
5
  end
@@ -1,26 +1,54 @@
1
+ require "xgb"
2
+ require "tempfile"
3
+ require "fileutils"
4
+
1
5
  module Wandb
2
- class XGBoostCallback
6
+ class XGBoostCallback < XGBoost::TrainingCallback
3
7
  MINIMIZE_METRICS = %w[rmse logloss error] # Add other metrics as needed
4
8
  MAXIMIZE_METRICS = %w[auc accuracy] # Add other metrics as needed
5
9
 
6
- def initialize(log_model: false, log_feature_importance: true, importance_type: "gain", define_metric: true)
7
- @log_model = log_model
8
- @log_feature_importance = log_feature_importance
9
- @importance_type = importance_type
10
- @define_metric = define_metric
10
+ class Opts
11
+ attr_accessor :options
12
+
13
+ def initialize(options = {})
14
+ @options = options
15
+ end
16
+
17
+ def default(key, default)
18
+ options.key?(key) ? options[key] : default
19
+ end
20
+ end
21
+
22
+ attr_accessor :project_name, :api_key
23
+
24
+ def initialize(options = {})
25
+ options = Opts.new(options)
26
+ @log_model = options.default(:log_model, false)
27
+ @log_feature_importance = options.default(:log_feature_importance, true)
28
+ @importance_type = options.default(:importance_type, "gain")
29
+ @define_metric = options.default(:define_metric, true)
30
+ @api_key = options.default(:api_key, ENV["WANDB_API_KEY"])
31
+ @project_name = options.default(:project_name, nil)
32
+
33
+ raise "WANDB_API_KEY required" unless api_key
34
+ raise "project_name required" unless project_name
35
+
36
+ Wandb.login(api_key: api_key)
37
+ Wandb.init(project: project_name)
11
38
 
12
39
  return if Wandb.current_run
13
40
 
14
41
  raise "You must call wandb.init() before WandbCallback()"
15
42
  end
16
43
 
17
- def before_training(model:)
44
+ def before_training(model)
18
45
  # Update Wandb config with model configuration
19
46
  Wandb.current_run.config = model.params
20
47
  Wandb.log(model.params)
48
+ model
21
49
  end
22
50
 
23
- def after_training(model:)
51
+ def after_training(model)
24
52
  # Log the model as an artifact
25
53
  log_model_as_artifact(model) if @log_model
26
54
 
@@ -28,54 +56,61 @@ module Wandb
28
56
  log_feature_importance(model) if @log_feature_importance
29
57
 
30
58
  # Log best score and best iteration
31
- return unless model.best_score
59
+ unless model.best_score
60
+ Wandb.finish
61
+ return model
62
+ end
32
63
 
33
64
  Wandb.log(
34
65
  "best_score" => model.best_score.to_f,
35
66
  "best_iteration" => model.best_iteration.to_i
36
67
  )
68
+
69
+ Wandb.finish
70
+ FileUtils.rm_rf(Rails.root.join("wandb"))
71
+
72
+ model
37
73
  end
38
74
 
39
- def before_iteration(model:, epoch:, evals:)
40
- # noop
75
+ def before_iteration(_model, _epoch, _history)
76
+ false
41
77
  end
42
78
 
43
- def after_iteration(model:, epoch:, evals:, res:)
44
- res.each do |metric_name, value|
45
- data, metric = metric_name.split("-", 2)
46
- full_metric_name = "#{data}-#{metric}"
47
-
48
- if @define_metric
49
- define_metric(data, metric)
50
- Wandb.log({ full_metric_name => value })
51
- else
52
- Wandb.log({ full_metric_name => value })
53
- end
54
- end
79
+ def after_iteration(_model, epoch, history)
80
+ history.each do |split, metric_scores|
81
+ metric = metric_scores.keys.first
82
+ values = metric_scores.values.last
83
+ epoch_value = values[epoch]
55
84
 
56
- Wandb.log({ "epoch" => epoch })
57
- @define_metric = false
85
+ define_metric(split, metric) if @define_metric && epoch == 0
86
+ full_metric_name = "#{split}-#{metric}"
87
+ Wandb.log({ full_metric_name => epoch_value })
88
+ end
89
+ Wandb.log("epoch" => epoch)
90
+ false
58
91
  end
59
92
 
60
93
  private
61
94
 
62
95
  def log_model_as_artifact(model)
63
- model_name = "#{Wandb.current_run.id}_model.json"
64
- model_path = File.join(Wandb.current_run.dir, model_name)
65
- model.save_model(model_path)
66
-
67
- model_artifact = Wandb.Artifact(name: model_name, type: "model")
68
- model_artifact.add_file(model_path)
69
- Wandb.current_run.log_artifact(model_artifact)
96
+ Dir.mktmpdir("wandb_xgboost_model") do |tmp_dir|
97
+ model_name = "model.json"
98
+ model_path = File.join(tmp_dir, model_name)
99
+ model.save_model(model_path)
100
+
101
+ model_artifact = Wandb.artifact(name: model_name, type: "model")
102
+ model_artifact.add_file(model_path)
103
+ Wandb.current_run.log_artifact(model_artifact)
104
+ end
70
105
  end
71
106
 
72
107
  def log_feature_importance(model)
73
108
  fi = model.score(importance_type: @importance_type)
74
109
  fi_data = fi.map { |k, v| [k, v] }
75
110
 
76
- table = Wandb.Table(data: fi_data, columns: %w[Feature Importance])
77
- bar_plot = Wandb.plot.bar(table, "Feature", "Importance", title: "Feature Importance")
78
- Wandb.log({ "Feature Importance" => bar_plot })
111
+ table = Wandb::Table.new(data: fi_data, columns: %w[Feature Importance])
112
+ bar_plot = Wandb::Plot.bar(table.table, "Feature", "Importance", title: "Feature Importance")
113
+ Wandb.log({ "Feature Importance" => bar_plot.__pyptr__ })
79
114
  end
80
115
 
81
116
  def define_metric(data, metric_name)
data/lib/wandb.rb CHANGED
@@ -16,10 +16,6 @@ module Wandb
16
16
  @wandb ||= PyCall.import_module("wandb")
17
17
  end
18
18
 
19
- def Table(*args, **kwargs)
20
- __pyptr__.Table.new(*args, **kwargs)
21
- end
22
-
23
19
  def plot(*args, **kwargs)
24
20
  __pyptr__.plot(*args, **kwargs)
25
21
  end
@@ -30,11 +26,12 @@ module Wandb
30
26
  end
31
27
 
32
28
  # Expose wandb.Artifact
33
- def Artifact(*args, **kwargs)
34
- __pyptr__.Artifact.new(*args, **kwargs)
29
+ def artifact(*args, **kwargs)
30
+ py_artifact = __pyptr__.Artifact.new(*args, **kwargs)
31
+ Artifact.new(py_artifact)
35
32
  end
36
33
 
37
- def Error
34
+ def error
38
35
  __pyptr__.Error
39
36
  end
40
37
 
@@ -72,6 +69,10 @@ module Wandb
72
69
  def api
73
70
  @api ||= Api.new(__pyptr__.Api.new)
74
71
  end
72
+
73
+ def plot
74
+ Plot
75
+ end
75
76
  end
76
77
 
77
78
  # Run class
@@ -80,6 +81,10 @@ module Wandb
80
81
  @run = run
81
82
  end
82
83
 
84
+ def run_id
85
+ @run.run_id
86
+ end
87
+
83
88
  def log(metrics = {})
84
89
  metrics.symbolize_keys!
85
90
  @run.log(metrics, {})
@@ -104,6 +109,53 @@ module Wandb
104
109
  def config=(new_config)
105
110
  @run.config.update(PyCall::Dict.new(new_config))
106
111
  end
112
+
113
+ def log_artifact(artifact)
114
+ @run.log_artifact(artifact.__pyptr__)
115
+ end
116
+ end
117
+
118
+ # Artifact class
119
+ class Artifact
120
+ def initialize(artifact)
121
+ @artifact = artifact
122
+ end
123
+
124
+ def __pyptr__
125
+ @artifact
126
+ end
127
+
128
+ def name
129
+ @artifact.name
130
+ end
131
+
132
+ def type
133
+ @artifact.type
134
+ end
135
+
136
+ def add_file(local_path, name = nil)
137
+ @artifact.add_file(local_path, name)
138
+ end
139
+
140
+ def add_dir(local_dir, name = nil)
141
+ @artifact.add_dir(local_dir, name)
142
+ end
143
+
144
+ def get_path(name)
145
+ @artifact.get_path(name)
146
+ end
147
+
148
+ def metadata
149
+ @artifact.metadata
150
+ end
151
+
152
+ def metadata=(new_metadata)
153
+ @artifact.metadata = new_metadata
154
+ end
155
+
156
+ def save
157
+ @artifact.save
158
+ end
107
159
  end
108
160
 
109
161
  # Api class
@@ -137,6 +189,78 @@ module Wandb
137
189
  @project.description
138
190
  end
139
191
  end
192
+
193
+ # Table class
194
+ class Table
195
+ attr_accessor :table, :data, :columns
196
+
197
+ def initialize(data: {}, columns: [])
198
+ @table = Wandb.__pyptr__.Table.new(data: data, columns: columns)
199
+ @data = data
200
+ @columns = columns
201
+ end
202
+
203
+ def __pyptr__
204
+ @table
205
+ end
206
+
207
+ def add_data(*args)
208
+ @table.add_data(*args)
209
+ end
210
+
211
+ def add_column(name, data)
212
+ @table.add_column(name, data)
213
+ end
214
+
215
+ def get_column(name)
216
+ @table.get_column(name)
217
+ end
218
+
219
+ def columns
220
+ @table.columns
221
+ end
222
+
223
+ def data
224
+ @table.data
225
+ end
226
+
227
+ def to_pandas
228
+ @table.get_dataframe
229
+ end
230
+ end
231
+
232
+ # Plot class
233
+ class Plot
234
+ class << self
235
+ def bar(table, x_key, y_key, title: nil)
236
+ py_plot = Wandb.__pyptr__.plot.bar(table.__pyptr__, x_key, y_key, title: title)
237
+ new(py_plot)
238
+ end
239
+
240
+ def line(table, x_key, y_key, title: nil)
241
+ py_plot = Wandb.__pyptr__.plot.line(table.__pyptr__, x_key, y_key, title: title)
242
+ new(py_plot)
243
+ end
244
+
245
+ def scatter(table, x_key, y_key, title: nil)
246
+ py_plot = Wandb.__pyptr__.plot.scatter(table.__pyptr__, x_key, y_key, title: title)
247
+ new(py_plot)
248
+ end
249
+
250
+ def histogram(table, value_key, title: nil)
251
+ py_plot = Wandb.__pyptr__.plot.histogram(table.__pyptr__, value_key, title: title)
252
+ new(py_plot)
253
+ end
254
+ end
255
+
256
+ def initialize(plot)
257
+ @plot = plot
258
+ end
259
+
260
+ def __pyptr__
261
+ @plot
262
+ end
263
+ end
140
264
  end
141
265
 
142
266
  require_relative "wandb/xgboost_callback"
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: wandb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.2
4
+ version: 0.1.5
5
5
  platform: ruby
6
6
  authors:
7
7
  - Brett Shollenberger
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2024-10-11 00:00:00.000000000 Z
11
+ date: 2024-10-16 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: pycall