wandb 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: '0384522cf34b40d0a03a36f32799bb12cf1811b41688d35820638bb590fa7c3c'
4
+ data.tar.gz: 3a0fb90233c6545cf66e351d6b2694bd485cd63eb70fcbc7520c028829d0667d
5
+ SHA512:
6
+ metadata.gz: 4aeafd34db937f033b25c0585eaa97d055ca521d4bd0f96ac3b83803dea8c779461a24452c0a6cbcdb075f57c1999ab1443c6e88af27131260ac6ec890fd55c9
7
+ data.tar.gz: 906f599bd6a224bedb92d74cdd9972079b1e54bbc182d4cfdb79cc074085cdf3dcb827dc389b5f01bc59183a6e4381d4ae459d1fa830297c30654ea5068f2944
@@ -0,0 +1,5 @@
1
# frozen_string_literal: true

# Top-level namespace of the Weights & Biases Ruby bindings.
module Wandb
  # Release version string of this gem.
  VERSION = "0.1.0"
end
@@ -0,0 +1,91 @@
1
module Wandb
  # XGBoost training callback that streams evaluation metrics, feature
  # importance and (optionally) the trained model to the active W&B run.
  #
  # The hooks (before_training, after_training, before_iteration,
  # after_iteration) are called with keyword arguments by the training loop.
  # Wandb.init must have been called before the callback is constructed.
  class XGBoostCallback
    # Lower-cased metric names (without the "<dataset>-" prefix) whose run
    # summary should track the minimum. Any name containing "loss" is also
    # treated as minimised.
    MINIMIZE_METRICS = %w[rmse rmsle mae mape mphe logloss error error@t merror].freeze
    # Lower-cased metric names whose run summary should track the maximum.
    MAXIMIZE_METRICS = %w[auc aucpr ndcg map ndcg@n map@n].freeze

    # @param log_model [Boolean] save the model and upload it as a "model" artifact after training
    # @param log_feature_importance [Boolean] log a feature-importance bar chart after training
    # @param importance_type [String] importance type forwarded to model.score
    # @param define_metric [Boolean] declare min/max summaries for recognised metrics on the first round
    # @raise [RuntimeError] if there is no active run (Wandb.current_run is nil)
    def initialize(log_model: false, log_feature_importance: true, importance_type: "gain", define_metric: true)
      @log_model = log_model
      @log_feature_importance = log_feature_importance
      @importance_type = importance_type
      @define_metric = define_metric

      raise "You must call Wandb.init() before XGBoostCallback.new" unless Wandb.current_run
    end

    # Copies the model parameters into the run config and also logs them once.
    def before_training(model:)
      params = model.params
      Wandb.current_run.config = params
      Wandb.log(params)
    end

    # Uploads the model / feature importance when enabled, then records the
    # best score and iteration if the model reports a best score.
    def after_training(model:)
      log_model_as_artifact(model) if @log_model
      log_feature_importance(model) if @log_feature_importance

      best_score = model.best_score
      return unless best_score

      Wandb.log(
        "best_score" => best_score.to_f,
        "best_iteration" => model.best_iteration.to_i
      )
    end

    # Nothing to do before a boosting round.
    def before_iteration(model:, epoch:, evals:)
      # noop
    end

    # Logs every evaluation result of one boosting round together with the
    # epoch in a single Wandb.log call. Each Wandb.log call advances the run's
    # step, so logging them one by one would scatter a round over many steps.
    #
    # @param res [Enumerable] (metric_name, value) pairs, e.g. ["train-rmse", 0.42];
    #   names are assumed to look like "<dataset>-<metric>" (TODO confirm with the trainer)
    def after_iteration(model:, epoch:, evals:, res:)
      round_metrics = {}
      res.each do |metric_name, value|
        data, metric = metric_name.split("-", 2)
        # Only "<dataset>-<metric>" names can be classified; others are just logged.
        define_metric(data, metric) if @define_metric && metric
        round_metrics[metric_name] = value
      end
      round_metrics["epoch"] = epoch

      Wandb.log(round_metrics)
      # Summaries only need to be declared once, on the first round.
      @define_metric = false
    end

    private

    # Saves the model as JSON inside the run directory and uploads it as a
    # "model" artifact named after the run id.
    def log_model_as_artifact(model)
      run = Wandb.current_run
      model_name = "#{run.id}_model.json"
      model_path = File.join(run.dir, model_name)
      model.save_model(model_path)

      model_artifact = Wandb.Artifact(name: model_name, type: "model")
      model_artifact.add_file(model_path)
      run.log_artifact(model_artifact)
    end

    # Logs a bar chart of model.score(importance_type: @importance_type).
    def log_feature_importance(model)
      importance = model.score(importance_type: @importance_type)

      table = Wandb.Table(data: importance.to_a, columns: %w[Feature Importance])
      bar_plot = Wandb.plot.bar(table, "Feature", "Importance", title: "Feature Importance")
      Wandb.log({ "Feature Importance" => bar_plot })
    end

    # Declares a "min" or "max" run summary for "<data>-<metric_name>" when the
    # metric's direction is known; unrecognised metrics are left alone.
    def define_metric(data, metric_name)
      full_metric_name = "#{data}-#{metric_name}"
      normalized = metric_name.downcase
      # XGBoost parameterizes some metrics with "@" (e.g. "ndcg@5", "error@0.6");
      # also compare their base name against the lists.
      candidates = [normalized, normalized.sub(/@.*\z/, "")]

      if normalized.include?("loss") || candidates.any? { |name| MINIMIZE_METRICS.include?(name) }
        Wandb.define_metric(full_metric_name, summary: "min")
      elsif candidates.any? { |name| MAXIMIZE_METRICS.include?(name) }
        Wandb.define_metric(full_metric_name, summary: "max")
      end
    end
  end
end
data/lib/wandb.rb ADDED
@@ -0,0 +1,139 @@
1
require_relative "wandb/version"
require "pycall/import"

# pry is not a declared dependency of this gem and nothing here calls it, so
# its absence must not stop the library from loading.
begin
  require "pry"
rescue LoadError
  nil
end

# Under PyCall, Python's sys.executable reports the Ruby binary. wandb appears to
# use WANDB__EXECUTABLE / sys.executable when launching Python helper processes,
# so point both at a real interpreter: python3 if found on PATH, else python.
# A WANDB__EXECUTABLE the user already exported is kept as-is.
if ENV.fetch("WANDB__EXECUTABLE", "").empty?
  path_dirs = ENV.fetch("PATH", "").split(File::PATH_SEPARATOR)
  # On Windows executables carry an extension listed in PATHEXT (e.g. ".EXE").
  exec_exts = [""] + ENV.fetch("PATHEXT", "").split(File::PATH_SEPARATOR)
  python_bin = %w[python3 python].product(path_dirs, exec_exts)
                                 .map { |cmd, dir, ext| File.join(dir, "#{cmd}#{ext}") }
                                 .find { |candidate| File.file?(candidate) && File.executable?(candidate) }
  ENV["WANDB__EXECUTABLE"] = python_bin if python_bin
end

# Leave sys.executable untouched when no interpreter could be determined.
unless ENV.fetch("WANDB__EXECUTABLE", "").empty?
  PyCall.import_module("sys").executable = ENV["WANDB__EXECUTABLE"]
end
9
+
10
module Wandb
  include PyCall::Import

  class << self
    # Lazily imports and memoizes the Python `wandb` module.
    def __pyptr__
      @wandb ||= PyCall.import_module("wandb")
    end

    # Builds a Python wandb.Table; arguments are forwarded unchanged.
    def Table(*args, **kwargs)
      __pyptr__.Table.new(*args, **kwargs)
    end

    # Exposes the Python wandb.plot module (e.g. Wandb.plot.bar(...)).
    # Plain forwarding, so ActiveSupport's `delegate` is not required.
    def plot(*args, &block)
      __pyptr__.plot(*args, &block)
    end

    # Forwards to wandb.define_metric(name: ..., **kwargs).
    # The name is converted to a UTF-8 *copy*, so the caller's string is never
    # mutated and frozen strings / Symbols are accepted.
    # NOTE(review): UTF-8 is presumably forced so PyCall passes a Python str
    # rather than bytes for binary-encoded Ruby strings — confirm.
    def define_metric(metric_name, **kwargs)
      name = metric_name.to_s.dup.force_encoding(Encoding::UTF_8)
      __pyptr__.define_metric(name: name, **kwargs)
    end

    # Builds a Python wandb.Artifact; arguments are forwarded unchanged.
    def Artifact(*args, **kwargs)
      __pyptr__.Artifact.new(*args, **kwargs)
    end

    # Exposes the Python wandb.Error exception class.
    def Error(*args, &block)
      __pyptr__.Error(*args, &block)
    end

    # Logs in to W&B. `api_key`, when given, is sent as wandb.login(key: ...);
    # other keyword arguments are passed through unchanged.
    def login(api_key: nil, **kwargs)
      kwargs[:key] = api_key if api_key
      __pyptr__.login(**kwargs)
    end

    # Starts a run with wandb.init(**kwargs) and makes it the current run.
    # @return [Wandb::Run]
    def init(**kwargs)
      @current_run = Run.new(__pyptr__.init(**kwargs))
    end

    # The run created by the latest init; nil before init and after finish.
    attr_reader :current_run

    # Logs a metrics Hash to the current run (keys are symbolized by Run#log).
    # @raise [RuntimeError] if no run is active
    def log(metrics = {})
      raise "No active run. Call Wandb.init() first." unless @current_run

      @current_run.log(metrics)
    end

    # Finishes the current run (if any), forgets it, then calls wandb.finish.
    def finish
      @current_run&.finish
      @current_run = nil
      __pyptr__.finish
    end

    # Memoized wrapper around the Python wandb.Api client.
    def api
      @api ||= Api.new(__pyptr__.Api.new)
    end
  end

  # Thin Ruby wrapper around a Python wandb run object.
  class Run
    def initialize(run)
      @run = run
    end

    # Logs metrics to this run. Keys are symbolized on a copy, so the caller's
    # Hash is not mutated (frozen Hashes are accepted); keys that cannot become
    # Symbols are kept as-is.
    def log(metrics = {})
      symbolized = metrics.transform_keys { |key| key.respond_to?(:to_sym) ? key.to_sym : key }
      # NOTE(review): the trailing empty Hash appears to be consumed by PyCall as
      # (empty) keyword arguments, so `symbolized` goes through as the positional
      # data dict instead of being expanded into kwargs — confirm before removing.
      @run.log(symbolized, {})
    end

    def finish
      @run.finish
    end

    # Run id as reported by the Python run object.
    def id
      @run.id
    end

    # Local directory of the run as reported by the Python run object.
    def dir
      @run.dir
    end

    # Uploads a wandb.Artifact to this run; extra options are forwarded.
    def log_artifact(artifact, **kwargs)
      @run.log_artifact(artifact, **kwargs)
    end

    def name
      @run.name
    end

    def name=(new_name)
      @run.name = new_name
    end

    def config
      @run.config
    end

    # Merges the given Hash into the run config (it does not replace it).
    def config=(new_config)
      @run.config.update(PyCall::Dict.new(new_config))
    end
  end

  # Thin Ruby wrapper around the Python wandb.Api client.
  class Api
    def initialize(api)
      @api = api
    end

    # @return [Array<Wandb::Project>] projects of `entity` (nil lets wandb choose)
    def projects(entity = nil)
      projects = @api.projects(entity)
      projects.map { |proj| Project.new(proj) }
    end

    # @return [Wandb::Project] the named project of `entity`
    def project(name, entity = nil)
      proj = @api.project(name, entity)
      Project.new(proj)
    end
  end

  # Thin Ruby wrapper around a Python wandb project object.
  class Project
    def initialize(project)
      @project = project
    end

    def name
      @project.name
    end

    def description
      @project.description
    end
  end
end
138
+
139
+ require_relative "wandb/xgboost_callback"
metadata ADDED
@@ -0,0 +1,61 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: wandb
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.1.0
5
+ platform: ruby
6
+ authors:
7
+ - Brett Shollenberger
8
+ autorequire:
9
+ bindir: bin
10
+ cert_chain: []
11
+ date: 2024-10-11 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: pycall
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - ">="
18
+ - !ruby/object:Gem::Version
19
+ version: '0'
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - ">="
25
+ - !ruby/object:Gem::Version
26
+ version: '0'
27
+ description: Log model runs to Weights & Biases
28
+ email:
29
+ - brett.shollenberger@gmail.com
30
+ executables: []
31
+ extensions: []
32
+ extra_rdoc_files: []
33
+ files:
34
+ - lib/wandb.rb
35
+ - lib/wandb/version.rb
36
+ - lib/wandb/xgboost_callback.rb
37
+ homepage: https://github.com/brettshollenberger/wandb_rb.git
38
+ licenses:
39
+ - MIT
40
+ metadata:
41
+ homepage_uri: https://github.com/brettshollenberger/wandb_rb.git
42
+ post_install_message:
43
+ rdoc_options: []
44
+ require_paths:
45
+ - lib
46
+ required_ruby_version: !ruby/object:Gem::Requirement
47
+ requirements:
48
+ - - ">="
49
+ - !ruby/object:Gem::Version
50
+ version: '2.5'
51
+ required_rubygems_version: !ruby/object:Gem::Requirement
52
+ requirements:
53
+ - - ">="
54
+ - !ruby/object:Gem::Version
55
+ version: '0'
56
+ requirements: []
57
+ rubygems_version: 3.4.10
58
+ signing_key:
59
+ specification_version: 4
60
+ summary: A Ruby integration for the Weights & Biases platform
61
+ test_files: []