wandb 0.1.1 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/lib/wandb/version.rb +1 -1
- data/lib/wandb/xgboost_callback.rb +70 -35
- data/lib/wandb.rb +135 -9
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 4a00bf9591759e1339bca2ff33aca8dc245f71e8ef8f9e3aa4cd70a59f598a10
|
4
|
+
data.tar.gz: 880bed08a72bb3f03d48d093ab68a94ab8b706b759ec10945008979fd0c044c8
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: b6148f71160761a123696f0beac8b2ebe68064a4f64bd55ce1e9ec417c44b5c032a9f60b77394a3ebd0f0535e60c2ce50e53b5606f8acb958bfd35148c0f4a53
|
7
|
+
data.tar.gz: 14a72f4b7b0bc05234d9e0e276c724b1077ac61c07b755f90ce365d4d596c9183928b5df6d9396391a9339d6441b92e1e4c0733f67ee19634ba7fec47acb6818
|
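data/lib/wandb/version.rb
CHANGED
(hunk not captured in this extract; per the summary above, 1 line added and 1 removed)
data/lib/wandb/xgboost_callback.rb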
CHANGED
@@ -1,26 +1,54 @@
|
|
1
|
+
require "xgb"
|
2
|
+
require "tempfile"
|
3
|
+
require "fileutils"
|
4
|
+
|
1
5
|
module Wandb
|
2
|
-
class XGBoostCallback
|
6
|
+
class XGBoostCallback < XGBoost::TrainingCallback
|
3
7
|
MINIMIZE_METRICS = %w[rmse logloss error] # Add other metrics as needed
|
4
8
|
MAXIMIZE_METRICS = %w[auc accuracy] # Add other metrics as needed
|
5
9
|
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
10
|
+
class Opts
|
11
|
+
attr_accessor :options
|
12
|
+
|
13
|
+
def initialize(options = {})
|
14
|
+
@options = options
|
15
|
+
end
|
16
|
+
|
17
|
+
def default(key, default)
|
18
|
+
options.key?(key) ? options[key] : default
|
19
|
+
end
|
20
|
+
end
|
21
|
+
|
22
|
+
attr_accessor :project_name, :api_key
|
23
|
+
|
24
|
+
def initialize(options = {})
|
25
|
+
options = Opts.new(options)
|
26
|
+
@log_model = options.default(:log_model, false)
|
27
|
+
@log_feature_importance = options.default(:log_feature_importance, true)
|
28
|
+
@importance_type = options.default(:importance_type, "gain")
|
29
|
+
@define_metric = options.default(:define_metric, true)
|
30
|
+
@api_key = options.default(:api_key, ENV["WANDB_API_KEY"])
|
31
|
+
@project_name = options.default(:project_name, nil)
|
32
|
+
|
33
|
+
raise "WANDB_API_KEY required" unless api_key
|
34
|
+
raise "project_name required" unless project_name
|
35
|
+
|
36
|
+
Wandb.login(api_key: api_key)
|
37
|
+
Wandb.init(project: project_name)
|
11
38
|
|
12
39
|
return if Wandb.current_run
|
13
40
|
|
14
41
|
raise "You must call wandb.init() before WandbCallback()"
|
15
42
|
end
|
16
43
|
|
17
|
-
def before_training(model
|
44
|
+
def before_training(model)
|
18
45
|
# Update Wandb config with model configuration
|
19
46
|
Wandb.current_run.config = model.params
|
20
47
|
Wandb.log(model.params)
|
48
|
+
model
|
21
49
|
end
|
22
50
|
|
23
|
-
def after_training(model
|
51
|
+
def after_training(model)
|
24
52
|
# Log the model as an artifact
|
25
53
|
log_model_as_artifact(model) if @log_model
|
26
54
|
|
@@ -28,54 +56,61 @@ module Wandb
|
|
28
56
|
log_feature_importance(model) if @log_feature_importance
|
29
57
|
|
30
58
|
# Log best score and best iteration
|
31
|
-
|
59
|
+
unless model.best_score
|
60
|
+
Wandb.finish
|
61
|
+
return model
|
62
|
+
end
|
32
63
|
|
33
64
|
Wandb.log(
|
34
65
|
"best_score" => model.best_score.to_f,
|
35
66
|
"best_iteration" => model.best_iteration.to_i
|
36
67
|
)
|
68
|
+
|
69
|
+
Wandb.finish
|
70
|
+
FileUtils.rm_rf(Rails.root.join("wandb"))
|
71
|
+
|
72
|
+
model
|
37
73
|
end
|
38
74
|
|
39
|
-
def before_iteration(
|
40
|
-
|
75
|
+
def before_iteration(_model, _epoch, _history)
|
76
|
+
false
|
41
77
|
end
|
42
78
|
|
43
|
-
def after_iteration(
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
if @define_metric
|
49
|
-
define_metric(data, metric)
|
50
|
-
Wandb.log({ full_metric_name => value })
|
51
|
-
else
|
52
|
-
Wandb.log({ full_metric_name => value })
|
53
|
-
end
|
54
|
-
end
|
79
|
+
def after_iteration(_model, epoch, history)
|
80
|
+
history.each do |split, metric_scores|
|
81
|
+
metric = metric_scores.keys.first
|
82
|
+
values = metric_scores.values.last
|
83
|
+
epoch_value = values[epoch]
|
55
84
|
|
56
|
-
|
57
|
-
|
85
|
+
define_metric(split, metric) if @define_metric && epoch == 0
|
86
|
+
full_metric_name = "#{split}-#{metric}"
|
87
|
+
Wandb.log({ full_metric_name => epoch_value })
|
88
|
+
end
|
89
|
+
Wandb.log("epoch" => epoch)
|
90
|
+
false
|
58
91
|
end
|
59
92
|
|
60
93
|
private
|
61
94
|
|
62
95
|
def log_model_as_artifact(model)
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
96
|
+
Dir.mktmpdir("wandb_xgboost_model") do |tmp_dir|
|
97
|
+
model_name = "model.json"
|
98
|
+
model_path = File.join(tmp_dir, model_name)
|
99
|
+
model.save_model(model_path)
|
100
|
+
|
101
|
+
model_artifact = Wandb.artifact(name: model_name, type: "model")
|
102
|
+
model_artifact.add_file(model_path)
|
103
|
+
Wandb.current_run.log_artifact(model_artifact)
|
104
|
+
end
|
70
105
|
end
|
71
106
|
|
72
107
|
def log_feature_importance(model)
|
73
108
|
fi = model.score(importance_type: @importance_type)
|
74
109
|
fi_data = fi.map { |k, v| [k, v] }
|
75
110
|
|
76
|
-
table = Wandb.
|
77
|
-
bar_plot = Wandb.
|
78
|
-
Wandb.log({ "Feature Importance" => bar_plot })
|
111
|
+
table = Wandb::Table.new(data: fi_data, columns: %w[Feature Importance])
|
112
|
+
bar_plot = Wandb::Plot.bar(table.table, "Feature", "Importance", title: "Feature Importance")
|
113
|
+
Wandb.log({ "Feature Importance" => bar_plot.__pyptr__ })
|
79
114
|
end
|
80
115
|
|
81
116
|
def define_metric(data, metric_name)
|
data/lib/wandb.rb
CHANGED
@@ -16,25 +16,24 @@ module Wandb
|
|
16
16
|
@wandb ||= PyCall.import_module("wandb")
|
17
17
|
end
|
18
18
|
|
19
|
-
def
|
20
|
-
__pyptr__.
|
19
|
+
def plot(*args, **kwargs)
|
20
|
+
__pyptr__.plot(*args, **kwargs)
|
21
21
|
end
|
22
22
|
|
23
|
-
# Expose wandb.plot
|
24
|
-
delegate :plot, to: :__pyptr__
|
25
|
-
|
26
23
|
# Expose define_metric
|
27
24
|
def define_metric(metric_name, **kwargs)
|
28
25
|
__pyptr__.define_metric(name: metric_name.force_encoding("UTF-8"), **kwargs)
|
29
26
|
end
|
30
27
|
|
31
28
|
# Expose wandb.Artifact
|
32
|
-
def
|
33
|
-
__pyptr__.Artifact.new(*args, **kwargs)
|
29
|
+
def artifact(*args, **kwargs)
|
30
|
+
py_artifact = __pyptr__.Artifact.new(*args, **kwargs)
|
31
|
+
Artifact.new(py_artifact)
|
34
32
|
end
|
35
33
|
|
36
|
-
|
37
|
-
|
34
|
+
def error
|
35
|
+
__pyptr__.Error
|
36
|
+
end
|
38
37
|
|
39
38
|
# Login to Wandb
|
40
39
|
def login(api_key: nil, **kwargs)
|
@@ -70,6 +69,10 @@ module Wandb
|
|
70
69
|
def api
|
71
70
|
@api ||= Api.new(__pyptr__.Api.new)
|
72
71
|
end
|
72
|
+
|
73
|
+
def plot
|
74
|
+
Plot
|
75
|
+
end
|
73
76
|
end
|
74
77
|
|
75
78
|
# Run class
|
@@ -78,6 +81,10 @@ module Wandb
|
|
78
81
|
@run = run
|
79
82
|
end
|
80
83
|
|
84
|
+
def run_id
|
85
|
+
@run.run_id
|
86
|
+
end
|
87
|
+
|
81
88
|
def log(metrics = {})
|
82
89
|
metrics.symbolize_keys!
|
83
90
|
@run.log(metrics, {})
|
@@ -102,6 +109,53 @@ module Wandb
|
|
102
109
|
def config=(new_config)
|
103
110
|
@run.config.update(PyCall::Dict.new(new_config))
|
104
111
|
end
|
112
|
+
|
113
|
+
def log_artifact(artifact)
|
114
|
+
@run.log_artifact(artifact.__pyptr__)
|
115
|
+
end
|
116
|
+
end
|
117
|
+
|
118
|
+
# Artifact class
|
119
|
+
class Artifact
|
120
|
+
def initialize(artifact)
|
121
|
+
@artifact = artifact
|
122
|
+
end
|
123
|
+
|
124
|
+
def __pyptr__
|
125
|
+
@artifact
|
126
|
+
end
|
127
|
+
|
128
|
+
def name
|
129
|
+
@artifact.name
|
130
|
+
end
|
131
|
+
|
132
|
+
def type
|
133
|
+
@artifact.type
|
134
|
+
end
|
135
|
+
|
136
|
+
def add_file(local_path, name = nil)
|
137
|
+
@artifact.add_file(local_path, name)
|
138
|
+
end
|
139
|
+
|
140
|
+
def add_dir(local_dir, name = nil)
|
141
|
+
@artifact.add_dir(local_dir, name)
|
142
|
+
end
|
143
|
+
|
144
|
+
def get_path(name)
|
145
|
+
@artifact.get_path(name)
|
146
|
+
end
|
147
|
+
|
148
|
+
def metadata
|
149
|
+
@artifact.metadata
|
150
|
+
end
|
151
|
+
|
152
|
+
def metadata=(new_metadata)
|
153
|
+
@artifact.metadata = new_metadata
|
154
|
+
end
|
155
|
+
|
156
|
+
def save
|
157
|
+
@artifact.save
|
158
|
+
end
|
105
159
|
end
|
106
160
|
|
107
161
|
# Api class
|
@@ -135,6 +189,78 @@ module Wandb
|
|
135
189
|
@project.description
|
136
190
|
end
|
137
191
|
end
|
192
|
+
|
193
|
+
# Table class
|
194
|
+
class Table
|
195
|
+
attr_accessor :table, :data, :columns
|
196
|
+
|
197
|
+
def initialize(data: {}, columns: [])
|
198
|
+
@table = Wandb.__pyptr__.Table.new(data: data, columns: columns)
|
199
|
+
@data = data
|
200
|
+
@columns = columns
|
201
|
+
end
|
202
|
+
|
203
|
+
def __pyptr__
|
204
|
+
@table
|
205
|
+
end
|
206
|
+
|
207
|
+
def add_data(*args)
|
208
|
+
@table.add_data(*args)
|
209
|
+
end
|
210
|
+
|
211
|
+
def add_column(name, data)
|
212
|
+
@table.add_column(name, data)
|
213
|
+
end
|
214
|
+
|
215
|
+
def get_column(name)
|
216
|
+
@table.get_column(name)
|
217
|
+
end
|
218
|
+
|
219
|
+
def columns
|
220
|
+
@table.columns
|
221
|
+
end
|
222
|
+
|
223
|
+
def data
|
224
|
+
@table.data
|
225
|
+
end
|
226
|
+
|
227
|
+
def to_pandas
|
228
|
+
@table.get_dataframe
|
229
|
+
end
|
230
|
+
end
|
231
|
+
|
232
|
+
# Plot class
|
233
|
+
class Plot
|
234
|
+
class << self
|
235
|
+
def bar(table, x_key, y_key, title: nil)
|
236
|
+
py_plot = Wandb.__pyptr__.plot.bar(table.__pyptr__, x_key, y_key, title: title)
|
237
|
+
new(py_plot)
|
238
|
+
end
|
239
|
+
|
240
|
+
def line(table, x_key, y_key, title: nil)
|
241
|
+
py_plot = Wandb.__pyptr__.plot.line(table.__pyptr__, x_key, y_key, title: title)
|
242
|
+
new(py_plot)
|
243
|
+
end
|
244
|
+
|
245
|
+
def scatter(table, x_key, y_key, title: nil)
|
246
|
+
py_plot = Wandb.__pyptr__.plot.scatter(table.__pyptr__, x_key, y_key, title: title)
|
247
|
+
new(py_plot)
|
248
|
+
end
|
249
|
+
|
250
|
+
def histogram(table, value_key, title: nil)
|
251
|
+
py_plot = Wandb.__pyptr__.plot.histogram(table.__pyptr__, value_key, title: title)
|
252
|
+
new(py_plot)
|
253
|
+
end
|
254
|
+
end
|
255
|
+
|
256
|
+
def initialize(plot)
|
257
|
+
@plot = plot
|
258
|
+
end
|
259
|
+
|
260
|
+
def __pyptr__
|
261
|
+
@plot
|
262
|
+
end
|
263
|
+
end
|
138
264
|
end
|
139
265
|
|
140
266
|
require_relative "wandb/xgboost_callback"
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: wandb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.1
|
4
|
+
version: 0.1.5
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Brett Shollenberger
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2024-10-
|
11
|
+
date: 2024-10-16 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: pycall
|