wandb 0.1.6 → 0.1.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/lib/wandb/version.rb +1 -1
- data/lib/wandb/xgboost_callback.rb +25 -13
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 17b68d037ed6930865fae0bd24437e1c41d3a746208c69237bf2cd9da3ed2c4e
|
4
|
+
data.tar.gz: be8a03c3963951579dadd5397780cdd1d003f0b2e7e83f6a181b6dda491267de
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: e7e1edea11690670d9d1af572c99caea8d6985fa31bd3e6255602352081d794bc00690a95042bf301588b160952de808be7f242b4f25cf6c54a8d146e966ccdb
|
7
|
+
data.tar.gz: 1c3120cf18ba4ee4cc2cd3e26f61cccc296cc7059b823f0d8892a6b27571a7de120a8d3378151010793b42e99539e722c18a7aae363af465cbc53da26098b7a1
|
data/lib/wandb/xgboost_callback.rb
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
require "xgb"
|
2
2
|
require "tempfile"
|
3
3
|
require "fileutils"
|
4
|
+
require "xgboost/training_callback"
|
4
5
|
|
5
6
|
module Wandb
|
6
7
|
class XGBoostCallback < XGBoost::TrainingCallback
|
@@ -19,16 +20,18 @@ module Wandb
|
|
19
20
|
end
|
20
21
|
end
|
21
22
|
|
22
|
-
attr_accessor :project_name, :api_key, :custom_loggers
|
23
|
+
attr_accessor :project_name, :api_key, :custom_loggers, :history, :sample
|
23
24
|
|
24
25
|
def initialize(options = {})
|
25
26
|
options = Opts.new(options)
|
26
27
|
@log_model = options.default(:log_model, false)
|
27
28
|
@log_feature_importance = options.default(:log_feature_importance, true)
|
28
29
|
@importance_type = options.default(:importance_type, "gain")
|
30
|
+
@normalize_feature_importance = options.default(:normalize_feature_importance, true)
|
29
31
|
@define_metric = options.default(:define_metric, true)
|
30
32
|
@api_key = options.default(:api_key, ENV["WANDB_API_KEY"])
|
31
33
|
@project_name = options.default(:project_name, nil)
|
34
|
+
@sample = options.default(:sample, 1.0)
|
32
35
|
@custom_loggers = options.default(:custom_loggers, [])
|
33
36
|
end
|
34
37
|
|
@@ -79,19 +82,22 @@ module Wandb
|
|
79
82
|
end
|
80
83
|
|
81
84
|
def after_iteration(model, epoch, history)
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
85
|
+
log_frequency = (1.0 / @sample).round
|
86
|
+
if epoch % log_frequency == 0
|
87
|
+
history.to_h.each do |split, metric_scores|
|
88
|
+
metric = metric_scores.keys.first
|
89
|
+
values = metric_scores.values.last
|
90
|
+
epoch_value = values[epoch]
|
91
|
+
|
92
|
+
define_metric(split, metric) if @define_metric && epoch == 0
|
93
|
+
full_metric_name = "#{split}-#{metric}"
|
94
|
+
Wandb.log({ full_metric_name => epoch_value })
|
95
|
+
end
|
96
|
+
@custom_loggers.each do |logger|
|
97
|
+
logger.call(model, epoch, history)
|
98
|
+
end
|
99
|
+
Wandb.log("epoch" => epoch)
|
90
100
|
end
|
91
|
-
@custom_loggers.each do |logger|
|
92
|
-
logger.call(model, epoch, history)
|
93
|
-
end
|
94
|
-
Wandb.log("epoch" => epoch)
|
95
101
|
false
|
96
102
|
end
|
97
103
|
|
@@ -111,6 +117,12 @@ module Wandb
|
|
111
117
|
|
112
118
|
def log_feature_importance(model)
|
113
119
|
fi = model.score(importance_type: @importance_type)
|
120
|
+
|
121
|
+
if @normalize_feature_importance
|
122
|
+
total_importance = fi.values.sum
|
123
|
+
fi = fi.transform_values { |v| v / total_importance }
|
124
|
+
end
|
125
|
+
|
114
126
|
fi_data = fi.map { |k, v| [k, v] }
|
115
127
|
|
116
128
|
table = Wandb::Table.new(data: fi_data, columns: %w[Feature Importance])
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: wandb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.6
|
4
|
+
version: 0.1.8
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Brett Shollenberger
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2024-10-
|
11
|
+
date: 2024-10-20 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: pycall
|