wandb 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/lib/wandb/version.rb +5 -0
- data/lib/wandb/xgboost_callback.rb +91 -0
- data/lib/wandb.rb +139 -0
- metadata +61 -0
checksums.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
1
|
+
---
|
2
|
+
SHA256:
|
3
|
+
metadata.gz: '0384522cf34b40d0a03a36f32799bb12cf1811b41688d35820638bb590fa7c3c'
|
4
|
+
data.tar.gz: 3a0fb90233c6545cf66e351d6b2694bd485cd63eb70fcbc7520c028829d0667d
|
5
|
+
SHA512:
|
6
|
+
metadata.gz: 4aeafd34db937f033b25c0585eaa97d055ca521d4bd0f96ac3b83803dea8c779461a24452c0a6cbcdb075f57c1999ab1443c6e88af27131260ac6ec890fd55c9
|
7
|
+
data.tar.gz: 906f599bd6a224bedb92d74cdd9972079b1e54bbc182d4cfdb79cc074085cdf3dcb827dc389b5f01bc59183a6e4381d4ae459d1fa830297c30654ea5068f2944
|
@@ -0,0 +1,91 @@
|
|
1
|
+
module Wandb
  # XGBoost training callback that reports to the active Weights & Biases run.
  # It records:
  #   * model params as run config,
  #   * per-iteration evaluation metrics (one W&B step per iteration),
  #   * optionally, a feature-importance bar chart,
  #   * optionally, the saved model as an artifact,
  #   * the best score/iteration when the model reports one.
  #
  # Wandb.init must be called before the callback is constructed.
  class XGBoostCallback
    # Metric names whose W&B summary should track the minimum value.
    MINIMIZE_METRICS = %w[rmse logloss error].freeze # Add other metrics as needed
    # Metric names whose W&B summary should track the maximum value.
    MAXIMIZE_METRICS = %w[auc accuracy].freeze # Add other metrics as needed

    # @param log_model [Boolean] save the trained model and log it as a W&B artifact
    # @param log_feature_importance [Boolean] log a feature-importance bar chart after training
    # @param importance_type [String] importance type passed to model.score (e.g. "gain")
    # @param define_metric [Boolean] register min/max summaries for recognised
    #   metrics on the first iteration
    # @raise [RuntimeError] if there is no active run (Wandb.current_run is nil)
    def initialize(log_model: false, log_feature_importance: true, importance_type: "gain", define_metric: true)
      @log_model = log_model
      @log_feature_importance = log_feature_importance
      @importance_type = importance_type
      @define_metric = define_metric

      return if Wandb.current_run

      raise "You must call Wandb.init() before XGBoostCallback.new"
    end

    # Copies the model's params into the run config and also logs them.
    def before_training(model:)
      Wandb.current_run.config = model.params
      Wandb.log(model.params)
    end

    # Logs the optional model artifact and feature importance. If the model
    # reports a best_score, also logs best_score/best_iteration.
    def after_training(model:)
      log_model_as_artifact(model) if @log_model
      log_feature_importance(model) if @log_feature_importance

      return unless model.best_score

      Wandb.log(
        "best_score" => model.best_score.to_f,
        "best_iteration" => model.best_iteration.to_i
      )
    end

    def before_iteration(model:, epoch:, evals:)
      # noop
    end

    # Logs every metric in res together with the epoch in a SINGLE Wandb.log call.
    # Each Wandb.log call commits a new W&B step, so logging metrics one at a
    # time would scatter one boosting round across several steps.
    #
    # @param epoch [Integer] current iteration number
    # @param res [Hash] metric values, keyed like "<dataset>-<metric>" (e.g. "train-rmse")
    def after_iteration(model:, epoch:, evals:, res:)
      metrics = {}
      res.each do |metric_name, value|
        full_metric_name = metric_name.to_s
        # Text after the first "-" is the metric itself. A key with no
        # dataset prefix is used whole instead of becoming "<metric>-".
        metric = full_metric_name.split("-", 2).last.to_s
        define_metric(full_metric_name, metric) if @define_metric
        metrics[full_metric_name] = value
      end
      metrics["epoch"] = epoch

      Wandb.log(metrics)
      # Summaries only need to be registered once, on the first iteration.
      @define_metric = false
    end

    private

    # Saves the model as JSON in the run directory and logs it as a "model" artifact.
    def log_model_as_artifact(model)
      model_name = "#{Wandb.current_run.id}_model.json"
      model_path = File.join(Wandb.current_run.dir, model_name)
      model.save_model(model_path)

      model_artifact = Wandb.Artifact(name: model_name, type: "model")
      model_artifact.add_file(model_path)
      Wandb.current_run.log_artifact(model_artifact)
    end

    # Logs a "Feature Importance" bar chart built from model.score(importance_type:).
    def log_feature_importance(model)
      importances = model.score(importance_type: @importance_type)

      table = Wandb.Table(data: importances.to_a, columns: %w[Feature Importance])
      bar_plot = Wandb.plot.bar(table, "Feature", "Importance", title: "Feature Importance")
      Wandb.log({ "Feature Importance" => bar_plot })
    end

    # Registers a W&B summary for full_metric_name based on its metric part.
    # Metrics containing "loss" or listed in MINIMIZE_METRICS track the min.
    # Metrics listed in MAXIMIZE_METRICS track the max. Anything else is left alone.
    def define_metric(full_metric_name, metric)
      name = metric.downcase

      if name.include?("loss") || MINIMIZE_METRICS.include?(name)
        Wandb.define_metric(full_metric_name, summary: "min")
      elsif MAXIMIZE_METRICS.include?(name)
        Wandb.define_metric(full_metric_name, summary: "max")
      end
    end
  end
end
|
data/lib/wandb.rb
ADDED
@@ -0,0 +1,139 @@
|
|
1
|
+
require_relative "wandb/version"
|
2
|
+
require "pry"
|
3
|
+
require "pycall/import"
|
4
|
+
|
5
|
+
# Ensure wandb executable isn't set using Ruby context. Point both
# WANDB__EXECUTABLE and Python's sys.executable at a real Python interpreter,
# preferring python3 and falling back to python. The PATH lookup for
# python3 runs only once.
# NOTE(review): presumably sys.executable under PyCall would otherwise name the
# Ruby process — confirm. If neither interpreter is found, this becomes "".
python_executable = `which python3`.chomp
python_executable = `which python`.chomp if python_executable.empty?
ENV["WANDB__EXECUTABLE"] = python_executable
py_sys = PyCall.import_module("sys")
py_sys.executable = ENV["WANDB__EXECUTABLE"]
|
9
|
+
|
10
|
+
module Wandb
  include PyCall::Import

  class << self
    # Lazily imports and memoizes the Python `wandb` module.
    def __pyptr__
      @wandb ||= PyCall.import_module("wandb")
    end

    # Builds a Python wandb.Table; all arguments are forwarded unchanged.
    def Table(*args, **kwargs)
      __pyptr__.Table.new(*args, **kwargs)
    end

    # Exposes the Python wandb.plot namespace.
    # Defined explicitly because `delegate` is an ActiveSupport extension,
    # and this gem does not depend on ActiveSupport.
    def plot
      __pyptr__.plot
    end

    # Registers a metric with W&B through Python wandb.define_metric.
    # The name is copied before being tagged as UTF-8, so the caller's string
    # is never mutated and frozen strings (or symbols) are accepted.
    #
    # @param metric_name [String, Symbol]
    # @param kwargs forwarded to wandb.define_metric (e.g. summary: "min")
    def define_metric(metric_name, **kwargs)
      __pyptr__.define_metric(name: metric_name.to_s.dup.force_encoding("UTF-8"), **kwargs)
    end

    # Builds a Python wandb.Artifact; all arguments are forwarded unchanged.
    def Artifact(*args, **kwargs)
      __pyptr__.Artifact.new(*args, **kwargs)
    end

    # Exposes the Python wandb.Error class (defined without ActiveSupport `delegate`).
    def Error
      __pyptr__.Error
    end

    # Logs in to W&B. api_key, when given, is passed to Python as `key`.
    # Any other keyword arguments are forwarded as-is.
    def login(api_key: nil, **kwargs)
      kwargs[:key] = api_key if api_key
      __pyptr__.login(**kwargs)
    end

    # Starts a new run via wandb.init and makes it the current run.
    # @return [Run]
    def init(**kwargs)
      run = __pyptr__.init(**kwargs)
      @current_run = Run.new(run)
    end

    # The Run created by the most recent init, or nil after finish.
    attr_reader :current_run

    # Logs metrics to the current run; keys are symbolized by Run#log.
    # @raise [RuntimeError] if no run is active
    def log(metrics = {})
      raise "No active run. Call Wandb.init() first." unless @current_run

      @current_run.log(metrics)
    end

    # Finishes the current run (if any), clears it, then calls wandb.finish.
    def finish
      @current_run.finish if @current_run
      @current_run = nil
      __pyptr__.finish
    end

    # Memoized wrapper around the Python wandb.Api.
    def api
      @api ||= Api.new(__pyptr__.Api.new)
    end
  end

  # Thin wrapper around a Python wandb run object.
  class Run
    def initialize(run)
      @run = run
    end

    # Logs metrics with symbolized keys, without mutating the caller's hash.
    # NOTE(review): the trailing {} is presumably read by PyCall as (empty)
    # keyword arguments, so `metrics` goes through as the positional dict — confirm.
    def log(metrics = {})
      @run.log(symbolize_keys(metrics), {})
    end

    def finish
      @run.finish
    end

    def name
      @run.name
    end

    def name=(new_name)
      @run.name = new_name
    end

    def config
      @run.config
    end

    # Merges new_config into the run's config (Python dict update).
    def config=(new_config)
      @run.config.update(PyCall::Dict.new(new_config))
    end

    private

    # Returns a copy of hash with symbol keys. Keys that cannot be converted
    # are kept unchanged, matching ActiveSupport's Hash#symbolize_keys.
    def symbolize_keys(hash)
      hash.transform_keys { |key| key.respond_to?(:to_sym) ? key.to_sym : key }
    end
  end

  # Thin wrapper around the Python wandb.Api object.
  class Api
    def initialize(api)
      @api = api
    end

    # @return [Array<Project>] projects for entity (or the default entity when nil)
    def projects(entity = nil)
      projects = @api.projects(entity)
      projects.map { |proj| Project.new(proj) }
    end

    # @return [Project]
    def project(name, entity = nil)
      proj = @api.project(name, entity)
      Project.new(proj)
    end
  end

  # Thin wrapper around a Python wandb project object.
  class Project
    def initialize(project)
      @project = project
    end

    def name
      @project.name
    end

    def description
      @project.description
    end
  end
end
|
138
|
+
|
139
|
+
require_relative "wandb/xgboost_callback"
|
metadata
ADDED
@@ -0,0 +1,61 @@
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
2
|
+
name: wandb
|
3
|
+
version: !ruby/object:Gem::Version
|
4
|
+
version: 0.1.0
|
5
|
+
platform: ruby
|
6
|
+
authors:
|
7
|
+
- Brett Shollenberger
|
8
|
+
autorequire:
|
9
|
+
bindir: bin
|
10
|
+
cert_chain: []
|
11
|
+
date: 2024-10-11 00:00:00.000000000 Z
|
12
|
+
dependencies:
|
13
|
+
- !ruby/object:Gem::Dependency
|
14
|
+
name: pycall
|
15
|
+
requirement: !ruby/object:Gem::Requirement
|
16
|
+
requirements:
|
17
|
+
- - ">="
|
18
|
+
- !ruby/object:Gem::Version
|
19
|
+
version: '0'
|
20
|
+
type: :runtime
|
21
|
+
prerelease: false
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
23
|
+
requirements:
|
24
|
+
- - ">="
|
25
|
+
- !ruby/object:Gem::Version
|
26
|
+
version: '0'
|
27
|
+
description: Log model runs to Weights & Biases
|
28
|
+
email:
|
29
|
+
- brett.shollenberger@gmail.com
|
30
|
+
executables: []
|
31
|
+
extensions: []
|
32
|
+
extra_rdoc_files: []
|
33
|
+
files:
|
34
|
+
- lib/wandb.rb
|
35
|
+
- lib/wandb/version.rb
|
36
|
+
- lib/wandb/xgboost_callback.rb
|
37
|
+
homepage: https://github.com/brettshollenberger/wandb_rb.git
|
38
|
+
licenses:
|
39
|
+
- MIT
|
40
|
+
metadata:
|
41
|
+
homepage_uri: https://github.com/brettshollenberger/wandb_rb.git
|
42
|
+
post_install_message:
|
43
|
+
rdoc_options: []
|
44
|
+
require_paths:
|
45
|
+
- lib
|
46
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
47
|
+
requirements:
|
48
|
+
- - ">="
|
49
|
+
- !ruby/object:Gem::Version
|
50
|
+
version: '2.5'
|
51
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
52
|
+
requirements:
|
53
|
+
- - ">="
|
54
|
+
- !ruby/object:Gem::Version
|
55
|
+
version: '0'
|
56
|
+
requirements: []
|
57
|
+
rubygems_version: 3.4.10
|
58
|
+
signing_key:
|
59
|
+
specification_version: 4
|
60
|
+
summary: A Ruby integration for the Weights & Biases platform
|
61
|
+
test_files: []
|