classic_bandit 0.1.0 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.rubocop.yml +3 -0
- data/Rakefile +3 -1
- data/example/Gemfile +12 -0
- data/example/Gemfile.lock +20 -0
- data/example/beta_random.rb +81 -0
- data/example/simulation.rb +73 -0
- data/lib/classic_bandit/version.rb +1 -1
- metadata +6 -2
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: adf0483eb642e9d50a1265baeb748665a45209a1bea1ed1e101a134936236faa
|
|
4
|
+
data.tar.gz: 3f4cff2ed59733a2694af2387f91a0fccfdbe30e26d9879984533815abc7b3fd
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: f1abb6ea6d2c7aa56648242e318ae6ce02e5047cf13fee0dc01412030716629fef088f44b849bce50a45590f9e284a80230a3a5289b9fa1f4c65267bdaefc838
|
|
7
|
+
data.tar.gz: 49abf9eba86caefb27b3f4fd067d9aa20865e9bac066baca682a01ee34e3273883cf9113f36ebe6f47376c029377fd69749e3b7bd9ba477468a253ce9df5b422
|
data/.rubocop.yml
CHANGED
data/Rakefile
CHANGED
data/example/Gemfile
ADDED
data/example/Gemfile.lock
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
GEM
|
|
2
|
+
remote: https://rubygems.org/
|
|
3
|
+
specs:
|
|
4
|
+
classic_bandit (0.1.0)
|
|
5
|
+
gnuplot (2.6.2)
|
|
6
|
+
matrix (0.4.2)
|
|
7
|
+
zeitwerk (2.7.1)
|
|
8
|
+
|
|
9
|
+
PLATFORMS
|
|
10
|
+
ruby
|
|
11
|
+
x86_64-linux
|
|
12
|
+
|
|
13
|
+
DEPENDENCIES
|
|
14
|
+
classic_bandit (~> 0.1.0)
|
|
15
|
+
gnuplot (~> 2.6)
|
|
16
|
+
matrix (~> 0.4.2)
|
|
17
|
+
zeitwerk (~> 2.7)
|
|
18
|
+
|
|
19
|
+
BUNDLED WITH
|
|
20
|
+
2.6.2
|
|
data/example/beta_random.rb
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "gnuplot"
|
|
4
|
+
|
|
5
|
+
# Draws one sample from Gamma(alpha, 1) (shape alpha, unit scale).
#
# For alpha >= 1 this uses the Marsaglia-Tsang rejection method. For
# 0 < alpha < 1 it draws from Gamma(alpha + 1) and applies the boost
# X * U**(1 / alpha), which yields a Gamma(alpha) variate.
#
# @param alpha [Numeric] shape parameter; must be positive
# @return [Float] a non-negative gamma variate
# @raise [ArgumentError] if alpha is not positive (the boost step would
#   otherwise return meaningless values such as 0 or NaN)
def gamma_random(alpha)
  raise ArgumentError, "alpha must be positive, got #{alpha}" unless alpha.positive?

  return gamma_random(alpha + 1) * rand**(1.0 / alpha) if alpha < 1

  # Marsaglia-Tsang method
  d = alpha - 1.0 / 3
  c = 1.0 / Math.sqrt(9 * d)

  loop do
    z = normal_random
    v = (1 + c * z)**3
    u = rand

    # z > -1/c is equivalent to v > 0; it is checked first so Math.log(v) is
    # only evaluated for positive v.
    return d * v if z > -1.0 / c && Math.log(u) < 0.5 * z * z + d * (1 - v + Math.log(v))
  end
end
|
|
20
|
+
|
|
21
|
+
# Draws one standard normal N(0, 1) variate with the Box-Muller transform.
#
# Only the cosine half of the Box-Muller pair is returned; the sine half is
# discarded.
#
# @return [Float] a finite standard normal sample
def normal_random
  # Kernel#rand returns a Float in [0, 1), so rand itself can be 0.0 and
  # Math.log(0.0) is -Infinity (yielding an infinite sample). 1.0 - rand lies
  # in (0, 1], which keeps the logarithm finite without changing the
  # distribution.
  r = Math.sqrt(-2 * Math.log(1.0 - rand))
  theta = 2 * Math::PI * rand
  r * Math.cos(theta)
end
|
|
26
|
+
|
|
27
|
+
# Computes the Beta function B(alpha, beta) = Gamma(alpha) * Gamma(beta) / Gamma(alpha + beta).
#
# The gamma terms are combined as logarithms via Math.lgamma (whose first
# element is log|Gamma(x)|) and exponentiated once at the end.
#
# @param alpha [Numeric] first argument
# @param beta [Numeric] second argument
# @return [Float] B(alpha, beta)
def beta_function(alpha, beta)
  log_numerator = Math.lgamma(alpha).first + Math.lgamma(beta).first
  log_denominator = Math.lgamma(alpha + beta).first
  Math.exp(log_numerator - log_denominator)
end
|
|
33
|
+
|
|
34
|
+
# Evaluates the probability density of Beta(alpha, beta) at x.
#
# The density is computed in log space:
#   log f(x) = (alpha - 1) log x + (beta - 1) log(1 - x) - log B(alpha, beta)
# Evaluating x**(alpha - 1) * (1 - x)**(beta - 1) / B(alpha, beta) directly
# underflows for large shape parameters (e.g. alpha = beta = 600 gives
# 0.0 / 0.0 = NaN), while the log-space form stays accurate.
#
# @param x [Numeric] point at which to evaluate the density
# @param alpha [Numeric] first shape parameter (assumed positive)
# @param beta [Numeric] second shape parameter (assumed positive)
# @return [Float] the density at x; 0.0 outside the open interval (0, 1)
def beta_pdf(x, alpha, beta)
  return 0.0 if x <= 0 || x >= 1

  log_beta = Math.lgamma(alpha).first + Math.lgamma(beta).first - Math.lgamma(alpha + beta).first
  Math.exp((alpha - 1) * Math.log(x) + (beta - 1) * Math.log(1 - x) - log_beta)
end
|
|
39
|
+
|
|
40
|
+
# Shape parameters of the Beta distribution being sampled. They are shared by
# the sampler and the theoretical curve so the two cannot drift apart.
alpha = 41
beta = 61
sample_count = 10_000
bin_count = 60

# If X ~ Gamma(alpha) and Y ~ Gamma(beta) are independent,
# X / (X + Y) ~ Beta(alpha, beta).
data = Array.new(sample_count) do
  x1 = gamma_random(alpha)
  x2 = gamma_random(beta)
  x1 / (x1 + x2)
end

Gnuplot.open do |gp|
  Gnuplot::Plot.new(gp) do |plot|
    plot.title "Beta distribution histogram"
    plot.xlabel "Value"
    # Bar heights are normalized to a probability density (see below).
    plot.ylabel "Density"
    plot.xrange "[0:1]"

    plot.set "style data histograms"
    plot.set "style fill solid 0.5"

    # Normalized histogram over [0, 1]: height = count / (n * bin_width), so
    # the bars integrate to 1 and are comparable with the PDF curve.
    bin_width = 1.0 / bin_count
    counts = Array.new(bin_count, 0)
    # Clamp the index so a value that rounds to exactly bin_count lands in
    # the last bin instead of creating an extra bar at x = 1.
    data.each { |v| counts[[(v / bin_width).floor, bin_count - 1].min] += 1 }
    # gnuplot's "boxes" style centers each box on its x value, so plot at bin
    # centers; left edges would shift every bar left by half a bin.
    centers = Array.new(bin_count) { |i| (i + 0.5) * bin_width }
    densities = counts.map { |count| count / (data.size * bin_width) }

    plot.data << Gnuplot::DataSet.new([centers, densities]) do |ds|
      ds.with = "boxes"
      ds.title = "Empirical"
    end

    x_points = (0..100).map { |i| i / 100.0 }
    y_points = x_points.map { |x| beta_pdf(x, alpha, beta) }
    plot.data << Gnuplot::DataSet.new([x_points, y_points]) do |ds|
      ds.with = "lines"
      ds.linewidth = 2
      ds.title = "Theoretical PDF"
    end
  end
end
|
|
data/example/simulation.rb
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require 'classic_bandit'
|
|
4
|
+
require 'gnuplot'
|
|
5
|
+
|
|
6
|
+
# The algorithms compared head-to-head. Each one receives its own freshly
# built arms with identical starting statistics (1000 trials each; 120, 110
# and 100 successes for arms 0, 1 and 2), so no Arm object is shared between
# bandits.
arm_specs = [[0, 120], [1, 110], [2, 100]]
fresh_arms = lambda do
  arm_specs.map { |id, successes| ClassicBandit::Arm.new(id: id, trials: 1000, successes: successes) }
end

bandits = {
  "UCB1" => ClassicBandit::Ucb1.new(arms: fresh_arms.call),
  "Thompson sampling" => ClassicBandit::ThompsonSampling.new(arms: fresh_arms.call)
}
|
|
18
|
+
|
|
19
|
+
# Ground-truth success rate of every arm, fixed at each arm's starting
# estimate. Rewards are drawn from these fixed rates. arm.mean_reward is the
# arm's running estimate and presumably changes as bandit.update records
# rewards; sampling from it would let the "truth" drift with the algorithm's
# own estimates.
true_rates = bandits.transform_values do |bandit|
  bandit.arms.to_h { |arm| [arm.id, arm.mean_reward] }
end

arm0_counts = Hash.new(0)
arm0_probs = bandits.keys.to_h { |key| [key, []] }
x_values = []

10_000.times do |i|
  bandits.each do |key, bandit|
    # Warm-up: the first 500 rounds pick an arm uniformly at random.
    arm = i < 500 ? bandit.arms.sample : bandit.select_arm

    # Bernoulli reward. rand is in [0, 1), so rand < p succeeds with
    # probability exactly p.
    reward = rand < true_rates[key][arm.id] ? 1 : 0
    bandit.update(arm, reward)

    arm0_counts[key] += 1 if arm.id == 0

    # Cumulative share of all pulls so far that went to arm 0.
    arm0_probs[key] << arm0_counts[key].to_f / (i + 1)
  end

  x_values << i + 1
end
|
|
45
|
+
|
|
46
|
+
Gnuplot.open do |gp|
  Gnuplot::Plot.new(gp) do |plot|
    plot.title "Bandit Selection Probability"
    plot.xlabel "Iterations"
    plot.ylabel "Probability"

    # Fix the y axis to [0, 1]; every plotted value is a fraction of pulls.
    plot.yrange "[0:1]"

    # Show a grid.
    plot.set "grid"

    # Plot one line per algorithm. The color is applied directly to each
    # data set; the palette cycles so that more bandits than colors never
    # index past the end (which would send an empty "rgb ''" to gnuplot).
    palette = ["#0060ad", "#dd181f"]
    bandits.each_key.with_index do |key, index|
      plot.data << Gnuplot::DataSet.new([x_values, arm0_probs[key]]) do |ds|
        ds.with = "lines"
        ds.linewidth = 2
        ds.linecolor = "rgb '#{palette[index % palette.size]}'"
        ds.title = key.to_s
      end
    end
  end
end
|
metadata
CHANGED
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
|
2
2
|
name: classic_bandit
|
|
3
3
|
version: !ruby/object:Gem::Version
|
|
4
|
-
version: 0.1.0
|
|
4
|
+
version: 0.1.1
|
|
5
5
|
platform: ruby
|
|
6
6
|
authors:
|
|
7
7
|
- Kohei Tsuyuki
|
|
8
8
|
autorequire:
|
|
9
9
|
bindir: exe
|
|
10
10
|
cert_chain: []
|
|
11
|
-
date: 2024-12-
|
|
11
|
+
date: 2024-12-28 00:00:00.000000000 Z
|
|
12
12
|
dependencies: []
|
|
13
13
|
description: Implementation of classic multi-armed bandit algorithms in Ruby. Supports
|
|
14
14
|
Thompson Sampling, UCB1, and Epsilon-Greedy strategies with a simple, consistent
|
|
@@ -26,6 +26,10 @@ files:
|
|
|
26
26
|
- LICENSE.txt
|
|
27
27
|
- README.md
|
|
28
28
|
- Rakefile
|
|
29
|
+
- example/Gemfile
|
|
30
|
+
- example/Gemfile.lock
|
|
31
|
+
- example/beta_random.rb
|
|
32
|
+
- example/simulation.rb
|
|
29
33
|
- lib/classic_bandit.rb
|
|
30
34
|
- lib/classic_bandit/arm.rb
|
|
31
35
|
- lib/classic_bandit/arm_updatable.rb
|