wandb 0.1.2 → 0.1.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/lib/wandb/version.rb +1 -1
- data/lib/wandb/xgboost_callback.rb +80 -40
- data/lib/wandb.rb +131 -7
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 310ede46ccba30aa00bfbb81fbb8ffdefd818528c7ae09523e7dca9629258892
|
4
|
+
data.tar.gz: 07643a463e807e99db1a4e874bd85b749f6d65b644f15e0423505ead63ec7749
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 00e7a6fe5e888d931c1abbc6047207859a46b024c83a05a6e87feeea8d20576f2f3da1d6a5c851c9dd945e7bd564ecfc56b5c4707ed3b54284aea97ecb4dac46
|
7
|
+
data.tar.gz: d4b2c592d6a4a5e079024d48cb7a97834a883f1e5e12dbe1c856255d4beed84c90f54f629d818cad14513f4778b08657e6781c958ff977fe607d1a74c0a2deec
|
data/lib/wandb/xgboost_callback.rb
CHANGED
@@ -1,26 +1,53 @@
|
|
1
|
+
require "xgb"
|
2
|
+
require "tempfile"
|
3
|
+
require "fileutils"
|
4
|
+
|
1
5
|
module Wandb
|
2
|
-
class XGBoostCallback
|
6
|
+
class XGBoostCallback < XGBoost::TrainingCallback
|
3
7
|
MINIMIZE_METRICS = %w[rmse logloss error] # Add other metrics as needed
|
4
8
|
MAXIMIZE_METRICS = %w[auc accuracy] # Add other metrics as needed
|
5
9
|
|
6
|
-
|
7
|
-
|
8
|
-
@log_feature_importance = log_feature_importance
|
9
|
-
@importance_type = importance_type
|
10
|
-
@define_metric = define_metric
|
10
|
+
class Opts
|
11
|
+
attr_accessor :options
|
11
12
|
|
12
|
-
|
13
|
+
def initialize(options = {})
|
14
|
+
@options = options
|
15
|
+
end
|
13
16
|
|
14
|
-
|
17
|
+
def default(key, default)
|
18
|
+
options.key?(key) ? options[key] : default
|
19
|
+
end
|
15
20
|
end
|
16
21
|
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
22
|
+
attr_accessor :project_name, :api_key, :custom_loggers
|
23
|
+
|
24
|
+
def initialize(options = {})
|
25
|
+
options = Opts.new(options)
|
26
|
+
@log_model = options.default(:log_model, false)
|
27
|
+
@log_feature_importance = options.default(:log_feature_importance, true)
|
28
|
+
@importance_type = options.default(:importance_type, "gain")
|
29
|
+
@define_metric = options.default(:define_metric, true)
|
30
|
+
@api_key = options.default(:api_key, ENV["WANDB_API_KEY"])
|
31
|
+
@project_name = options.default(:project_name, nil)
|
32
|
+
@custom_loggers = options.default(:custom_loggers, [])
|
21
33
|
end
|
22
34
|
|
23
|
-
def
|
35
|
+
def before_training(model)
|
36
|
+
Wandb.login(api_key: api_key)
|
37
|
+
Wandb.init(project: project_name)
|
38
|
+
config = JSON.parse(model.save_config)
|
39
|
+
log_conf = {
|
40
|
+
learning_rate: config.dig("learner", "gradient_booster", "tree_train_param", "learning_rate").to_f,
|
41
|
+
max_depth: config.dig("learner", "gradient_booster", "tree_train_param", "max_depth").to_f,
|
42
|
+
n_estimators: model.num_boosted_rounds
|
43
|
+
}
|
44
|
+
Wandb.current_run.config = log_conf
|
45
|
+
|
46
|
+
Wandb.log(log_conf)
|
47
|
+
model
|
48
|
+
end
|
49
|
+
|
50
|
+
def after_training(model)
|
24
51
|
# Log the model as an artifact
|
25
52
|
log_model_as_artifact(model) if @log_model
|
26
53
|
|
@@ -28,54 +55,67 @@ module Wandb
|
|
28
55
|
log_feature_importance(model) if @log_feature_importance
|
29
56
|
|
30
57
|
# Log best score and best iteration
|
31
|
-
|
58
|
+
unless model.best_score
|
59
|
+
finish
|
60
|
+
return model
|
61
|
+
end
|
32
62
|
|
33
63
|
Wandb.log(
|
34
64
|
"best_score" => model.best_score.to_f,
|
35
65
|
"best_iteration" => model.best_iteration.to_i
|
36
66
|
)
|
67
|
+
finish
|
68
|
+
|
69
|
+
model
|
37
70
|
end
|
38
71
|
|
39
|
-
def
|
40
|
-
|
72
|
+
def finish
|
73
|
+
Wandb.finish
|
74
|
+
FileUtils.rm_rf(File.join(Dir.pwd, "wandb"))
|
41
75
|
end
|
42
76
|
|
43
|
-
def
|
44
|
-
|
45
|
-
|
46
|
-
full_metric_name = "#{data}-#{metric}"
|
47
|
-
|
48
|
-
if @define_metric
|
49
|
-
define_metric(data, metric)
|
50
|
-
Wandb.log({ full_metric_name => value })
|
51
|
-
else
|
52
|
-
Wandb.log({ full_metric_name => value })
|
53
|
-
end
|
54
|
-
end
|
77
|
+
def before_iteration(_model, _epoch, _history)
|
78
|
+
false
|
79
|
+
end
|
55
80
|
|
56
|
-
|
57
|
-
|
81
|
+
def after_iteration(model, epoch, history)
|
82
|
+
history.each do |split, metric_scores|
|
83
|
+
metric = metric_scores.keys.first
|
84
|
+
values = metric_scores.values.last
|
85
|
+
epoch_value = values[epoch]
|
86
|
+
|
87
|
+
define_metric(split, metric) if @define_metric && epoch == 0
|
88
|
+
full_metric_name = "#{split}-#{metric}"
|
89
|
+
Wandb.log({ full_metric_name => epoch_value })
|
90
|
+
end
|
91
|
+
@custom_loggers.each do |logger|
|
92
|
+
logger.call(model, epoch, history)
|
93
|
+
end
|
94
|
+
Wandb.log("epoch" => epoch)
|
95
|
+
false
|
58
96
|
end
|
59
97
|
|
60
98
|
private
|
61
99
|
|
62
100
|
def log_model_as_artifact(model)
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
101
|
+
Dir.mktmpdir("wandb_xgboost_model") do |tmp_dir|
|
102
|
+
model_name = "model.json"
|
103
|
+
model_path = File.join(tmp_dir, model_name)
|
104
|
+
model.save_model(model_path)
|
105
|
+
|
106
|
+
model_artifact = Wandb.artifact(name: model_name, type: "model")
|
107
|
+
model_artifact.add_file(model_path)
|
108
|
+
Wandb.current_run.log_artifact(model_artifact)
|
109
|
+
end
|
70
110
|
end
|
71
111
|
|
72
112
|
def log_feature_importance(model)
|
73
113
|
fi = model.score(importance_type: @importance_type)
|
74
114
|
fi_data = fi.map { |k, v| [k, v] }
|
75
115
|
|
76
|
-
table = Wandb.
|
77
|
-
bar_plot = Wandb.
|
78
|
-
Wandb.log({ "Feature Importance" => bar_plot })
|
116
|
+
table = Wandb::Table.new(data: fi_data, columns: %w[Feature Importance])
|
117
|
+
bar_plot = Wandb::Plot.bar(table.table, "Feature", "Importance", title: "Feature Importance")
|
118
|
+
Wandb.log({ "Feature Importance" => bar_plot.__pyptr__ })
|
79
119
|
end
|
80
120
|
|
81
121
|
def define_metric(data, metric_name)
|
data/lib/wandb.rb
CHANGED
@@ -16,10 +16,6 @@ module Wandb
|
|
16
16
|
@wandb ||= PyCall.import_module("wandb")
|
17
17
|
end
|
18
18
|
|
19
|
-
def Table(*args, **kwargs)
|
20
|
-
__pyptr__.Table.new(*args, **kwargs)
|
21
|
-
end
|
22
|
-
|
23
19
|
def plot(*args, **kwargs)
|
24
20
|
__pyptr__.plot(*args, **kwargs)
|
25
21
|
end
|
@@ -30,11 +26,12 @@ module Wandb
|
|
30
26
|
end
|
31
27
|
|
32
28
|
# Expose wandb.Artifact
|
33
|
-
def
|
34
|
-
__pyptr__.Artifact.new(*args, **kwargs)
|
29
|
+
def artifact(*args, **kwargs)
|
30
|
+
py_artifact = __pyptr__.Artifact.new(*args, **kwargs)
|
31
|
+
Artifact.new(py_artifact)
|
35
32
|
end
|
36
33
|
|
37
|
-
def
|
34
|
+
def error
|
38
35
|
__pyptr__.Error
|
39
36
|
end
|
40
37
|
|
@@ -72,6 +69,10 @@ module Wandb
|
|
72
69
|
def api
|
73
70
|
@api ||= Api.new(__pyptr__.Api.new)
|
74
71
|
end
|
72
|
+
|
73
|
+
def plot
|
74
|
+
Plot
|
75
|
+
end
|
75
76
|
end
|
76
77
|
|
77
78
|
# Run class
|
@@ -80,6 +81,10 @@ module Wandb
|
|
80
81
|
@run = run
|
81
82
|
end
|
82
83
|
|
84
|
+
def run_id
|
85
|
+
@run.run_id
|
86
|
+
end
|
87
|
+
|
83
88
|
def log(metrics = {})
|
84
89
|
metrics.symbolize_keys!
|
85
90
|
@run.log(metrics, {})
|
@@ -104,6 +109,53 @@ module Wandb
|
|
104
109
|
def config=(new_config)
|
105
110
|
@run.config.update(PyCall::Dict.new(new_config))
|
106
111
|
end
|
112
|
+
|
113
|
+
def log_artifact(artifact)
|
114
|
+
@run.log_artifact(artifact.__pyptr__)
|
115
|
+
end
|
116
|
+
end
|
117
|
+
|
118
|
+
# Artifact class
|
119
|
+
class Artifact
|
120
|
+
def initialize(artifact)
|
121
|
+
@artifact = artifact
|
122
|
+
end
|
123
|
+
|
124
|
+
def __pyptr__
|
125
|
+
@artifact
|
126
|
+
end
|
127
|
+
|
128
|
+
def name
|
129
|
+
@artifact.name
|
130
|
+
end
|
131
|
+
|
132
|
+
def type
|
133
|
+
@artifact.type
|
134
|
+
end
|
135
|
+
|
136
|
+
def add_file(local_path, name = nil)
|
137
|
+
@artifact.add_file(local_path, name)
|
138
|
+
end
|
139
|
+
|
140
|
+
def add_dir(local_dir, name = nil)
|
141
|
+
@artifact.add_dir(local_dir, name)
|
142
|
+
end
|
143
|
+
|
144
|
+
def get_path(name)
|
145
|
+
@artifact.get_path(name)
|
146
|
+
end
|
147
|
+
|
148
|
+
def metadata
|
149
|
+
@artifact.metadata
|
150
|
+
end
|
151
|
+
|
152
|
+
def metadata=(new_metadata)
|
153
|
+
@artifact.metadata = new_metadata
|
154
|
+
end
|
155
|
+
|
156
|
+
def save
|
157
|
+
@artifact.save
|
158
|
+
end
|
107
159
|
end
|
108
160
|
|
109
161
|
# Api class
|
@@ -137,6 +189,78 @@ module Wandb
|
|
137
189
|
@project.description
|
138
190
|
end
|
139
191
|
end
|
192
|
+
|
193
|
+
# Table class
|
194
|
+
class Table
|
195
|
+
attr_accessor :table, :data, :columns
|
196
|
+
|
197
|
+
def initialize(data: {}, columns: [])
|
198
|
+
@table = Wandb.__pyptr__.Table.new(data: data, columns: columns)
|
199
|
+
@data = data
|
200
|
+
@columns = columns
|
201
|
+
end
|
202
|
+
|
203
|
+
def __pyptr__
|
204
|
+
@table
|
205
|
+
end
|
206
|
+
|
207
|
+
def add_data(*args)
|
208
|
+
@table.add_data(*args)
|
209
|
+
end
|
210
|
+
|
211
|
+
def add_column(name, data)
|
212
|
+
@table.add_column(name, data)
|
213
|
+
end
|
214
|
+
|
215
|
+
def get_column(name)
|
216
|
+
@table.get_column(name)
|
217
|
+
end
|
218
|
+
|
219
|
+
def columns
|
220
|
+
@table.columns
|
221
|
+
end
|
222
|
+
|
223
|
+
def data
|
224
|
+
@table.data
|
225
|
+
end
|
226
|
+
|
227
|
+
def to_pandas
|
228
|
+
@table.get_dataframe
|
229
|
+
end
|
230
|
+
end
|
231
|
+
|
232
|
+
# Plot class
|
233
|
+
class Plot
|
234
|
+
class << self
|
235
|
+
def bar(table, x_key, y_key, title: nil)
|
236
|
+
py_plot = Wandb.__pyptr__.plot.bar(table.__pyptr__, x_key, y_key, title: title)
|
237
|
+
new(py_plot)
|
238
|
+
end
|
239
|
+
|
240
|
+
def line(table, x_key, y_key, title: nil)
|
241
|
+
py_plot = Wandb.__pyptr__.plot.line(table.__pyptr__, x_key, y_key, title: title)
|
242
|
+
new(py_plot)
|
243
|
+
end
|
244
|
+
|
245
|
+
def scatter(table, x_key, y_key, title: nil)
|
246
|
+
py_plot = Wandb.__pyptr__.plot.scatter(table.__pyptr__, x_key, y_key, title: title)
|
247
|
+
new(py_plot)
|
248
|
+
end
|
249
|
+
|
250
|
+
def histogram(table, value_key, title: nil)
|
251
|
+
py_plot = Wandb.__pyptr__.plot.histogram(table.__pyptr__, value_key, title: title)
|
252
|
+
new(py_plot)
|
253
|
+
end
|
254
|
+
end
|
255
|
+
|
256
|
+
def initialize(plot)
|
257
|
+
@plot = plot
|
258
|
+
end
|
259
|
+
|
260
|
+
def __pyptr__
|
261
|
+
@plot
|
262
|
+
end
|
263
|
+
end
|
140
264
|
end
|
141
265
|
|
142
266
|
require_relative "wandb/xgboost_callback"
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: wandb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.2
|
4
|
+
version: 0.1.6
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Brett Shollenberger
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2024-10-
|
11
|
+
date: 2024-10-18 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: pycall
|