classic_bandit 0.1.0 → 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 9ce78adca73218db90a909daec3ccbb457417620d78ae2566c0a32ac7024a313
4
- data.tar.gz: 9b6debd514866747242bd00de9db4d97d1fd95e5a03166e52c96b812daa3d7f6
3
+ metadata.gz: adf0483eb642e9d50a1265baeb748665a45209a1bea1ed1e101a134936236faa
4
+ data.tar.gz: 3f4cff2ed59733a2694af2387f91a0fccfdbe30e26d9879984533815abc7b3fd
5
5
  SHA512:
6
- metadata.gz: 80cc50641016e81853f766645d1fe9baf3c7bf4b09c8bc9303ce083dac1bec574d78432dd174953785e3dd37032e34c2248d22f47acd1c08c02078f02ad6f8a7
7
- data.tar.gz: e5b534ad79cc2a91a95b617c315335cb8c399c7241308750d868b4eb13e60016382765ff1e83447495676414e1213a6b7212d31e9b18f2859063d5f491f80b25
6
+ metadata.gz: f1abb6ea6d2c7aa56648242e318ae6ce02e5047cf13fee0dc01412030716629fef088f44b849bce50a45590f9e284a80230a3a5289b9fa1f4c65267bdaefc838
7
+ data.tar.gz: 49abf9eba86caefb27b3f4fd067d9aa20865e9bac066baca682a01ee34e3273883cf9113f36ebe6f47376c029377fd69749e3b7bd9ba477468a253ce9df5b422
data/.rubocop.yml CHANGED
@@ -1,5 +1,8 @@
1
1
  AllCops:
2
2
  TargetRubyVersion: 3.0
3
+ Include:
4
+ - 'lib/**/*'
5
+ - 'spec/**/*'
3
6
 
4
7
  Style/StringLiterals:
5
8
  EnforcedStyle: double_quotes
data/Rakefile CHANGED
@@ -7,6 +7,8 @@ RSpec::Core::RakeTask.new(:spec)
7
7
 
8
8
  require "rubocop/rake_task"
9
9
 
10
- RuboCop::RakeTask.new
10
+ RuboCop::RakeTask.new do |task|
11
+ task.options = ["--config", ".rubocop.yml"] # Explicitly specify the config file
12
+ end
11
13
 
12
14
  task default: %i[spec rubocop]
data/example/Gemfile ADDED
@@ -0,0 +1,12 @@
1
+ # frozen_string_literal: true
2
+
3
+ source "https://rubygems.org"
4
+
5
+ # gem "rails"
6
+
7
+ gem "gnuplot", "~> 2.6"
8
+ gem "matrix", "~> 0.4.2"
9
+
10
+ gem "zeitwerk", "~> 2.7"
11
+
12
+ gem "classic_bandit", "~> 0.1.0"
@@ -0,0 +1,20 @@
1
+ GEM
2
+ remote: https://rubygems.org/
3
+ specs:
4
+ classic_bandit (0.1.0)
5
+ gnuplot (2.6.2)
6
+ matrix (0.4.2)
7
+ zeitwerk (2.7.1)
8
+
9
+ PLATFORMS
10
+ ruby
11
+ x86_64-linux
12
+
13
+ DEPENDENCIES
14
+ classic_bandit (~> 0.1.0)
15
+ gnuplot (~> 2.6)
16
+ matrix (~> 0.4.2)
17
+ zeitwerk (~> 2.7)
18
+
19
+ BUNDLED WITH
20
+ 2.6.2
@@ -0,0 +1,81 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "gnuplot"
4
+
5
# Draws one sample from a Gamma distribution with shape +alpha+ and unit scale.
#
# Shapes of 1 or more are sampled with the Marsaglia-Tsang rejection method.
# Shapes below 1 are sampled as Gamma(alpha + 1) * U**(1 / alpha), where U is
# uniform on [0, 1).
#
# Uses +normal_random+ for standard-normal draws and Kernel#rand for uniforms.
#
# @param alpha [Numeric] shape parameter; must be strictly positive
# @return [Float] the sampled value
# @raise [ArgumentError] if +alpha+ is not strictly positive (including NaN)
def gamma_random(alpha)
  # Without this guard, alpha == 0 evaluates rand**Infinity and negative shapes
  # silently produce draws from no meaningful distribution.
  raise ArgumentError, "alpha must be positive (got #{alpha.inspect})" unless alpha.positive?

  return gamma_random(alpha + 1) * rand**(1.0 / alpha) if alpha < 1

  # Marsaglia-Tsang method
  d = alpha - 1.0 / 3
  c = 1.0 / Math.sqrt(9 * d)

  loop do
    z = normal_random
    v = (1 + c * z)**3
    u = rand

    # z > -1 / c means 1 + c * z > 0, hence v > 0; the short-circuit && therefore
    # only evaluates Math.log(v) where it is defined.
    return d * v if z > -1.0 / c && Math.log(u) < 0.5 * z * z + d * (1 - v + Math.log(v))
  end
end
20
+
21
# Draws one sample from the standard normal distribution using the cosine
# branch of the Box-Muller transform.
#
# @return [Float] a finite sample
def normal_random
  # Kernel#rand returns values in [0, 1). Taking 1.0 - rand moves the range to
  # (0, 1], so Math.log never receives 0 and the radius is always finite.
  radius = Math.sqrt(-2 * Math.log(1.0 - rand))
  angle = 2 * Math::PI * rand
  radius * Math.cos(angle)
end
26
+
27
# Evaluates the Beta function B(alpha, beta) = G(alpha) * G(beta) / G(alpha + beta),
# where G is the gamma function.
#
# The three gamma terms are combined as logarithms and exponentiated once.
# Only the log-magnitude element returned by Math.lgamma is used; the sign
# element is ignored.
#
# @param alpha [Numeric] first parameter
# @param beta [Numeric] second parameter
# @return [Float] the value of B(alpha, beta)
def beta_function(alpha, beta)
  log_a, log_b, log_sum = [alpha, beta, alpha + beta].map { |arg| Math.lgamma(arg).first }
  Math.exp(log_a + log_b - log_sum)
end
33
+
34
# Probability density of the Beta(alpha, beta) distribution at +x+.
#
# The density is evaluated in log space and exponentiated once. Computing
# x**(alpha - 1) * (1 - x)**(beta - 1) and the Beta function separately
# underflows both to 0.0 for large parameters, which yields NaN.
#
# @param x [Numeric] point at which to evaluate the density
# @param alpha [Numeric] first shape parameter
# @param beta [Numeric] second shape parameter
# @return [Float] the density; 0.0 whenever x lies outside the open interval (0, 1)
def beta_pdf(x, alpha, beta)
  return 0.0 if x <= 0 || x >= 1

  # log B(alpha, beta); only the log-magnitude element of Math.lgamma is used.
  log_beta = Math.lgamma(alpha)[0] + Math.lgamma(beta)[0] - Math.lgamma(alpha + beta)[0]
  log_density = (alpha - 1) * Math.log(x) + (beta - 1) * Math.log(1 - x) - log_beta
  Math.exp(log_density)
end
39
+
40
# Build 10,000 samples of X / (X + Y), where X and Y are independent draws
# from gamma_random(41) and gamma_random(61). That ratio is Beta(41, 61)
# distributed.
data = 10_000.times.map do
  numerator = gamma_random(41)
  numerator / (numerator + gamma_random(61))
end
45
+
46
# Plot a density-normalised histogram of the values in `data` together with
# the theoretical Beta(41, 61) density curve.
Gnuplot.open do |gp|
  Gnuplot::Plot.new(gp) do |plot|
    plot.title "Beta distribution histogram"
    plot.xlabel "Value"
    plot.ylabel "Frequency"

    min_val = 0.0
    max_val = 1.0
    bin_count = 60
    bin_width = (max_val - min_val) / bin_count

    plot.xrange "[0:1]"
    total_count = data.length.to_f

    plot.set "style data histograms"
    plot.set "style fill solid 0.5"

    # Count samples per bin. Clamping keeps a sample that lands exactly on
    # max_val in the last bin instead of creating an extra one.
    counts = Array.new(bin_count, 0)
    data.each do |value|
      index = ((value - min_val) / bin_width).floor.clamp(0, bin_count - 1)
      counts[index] += 1
    end

    # gnuplot's "boxes" style centres each box on its x value, so the bins are
    # plotted at their centres. Using the left edge would shift the histogram
    # half a bin to the left of the theoretical curve.
    centers = Array.new(bin_count) { |i| min_val + (i + 0.5) * bin_width }
    # Divide by (sample count * bin width) so the bars integrate to 1.
    densities = counts.map { |count| count / (total_count * bin_width) }

    plot.data << Gnuplot::DataSet.new([centers, densities]) do |ds|
      ds.with = "boxes"
      ds.title = "Empirical"
    end

    x_points = (0..100).map { |i| i / 100.0 }
    y_points = x_points.map { |x| beta_pdf(x, 41, 61) }
    plot.data << Gnuplot::DataSet.new([x_points, y_points]) do |ds|
      ds.with = "lines"
      ds.linewidth = 2
      ds.title = "Theoretical PDF"
    end
  end
end
@@ -0,0 +1,73 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'classic_bandit'
4
+ require 'gnuplot'
5
+
6
# Every algorithm starts from the same arm statistics (1000 trials each with
# 120, 110 and 100 successes), but each one receives its own Arm instances.
initial_arms = lambda do
  [120, 110, 100].each_with_index.map do |successes, id|
    ClassicBandit::Arm.new(id: id, trials: 1000, successes: successes)
  end
end

# Algorithms under comparison, keyed by the label used in the plot legend.
bandits = {
  "UCB1" => ClassicBandit::Ucb1.new(arms: initial_arms.call),
  "Thompson sampling" => ClassicBandit::ThompsonSampling.new(arms: initial_arms.call)
}
18
+
19
# Per-algorithm bookkeeping: how often arm 0 was chosen, and the running
# fraction of rounds in which it was chosen.
arm0_counts = Hash.new(0)
arm0_probs = bandits.keys.to_h { |key| [key, []] }
x_values = []

# Freeze each arm's reward probability before the simulation starts.
# NOTE(review): presumably arm.mean_reward changes as bandit.update records
# outcomes; sampling rewards from the live value would then let the simulated
# "true" rate drift with the bandit's own estimate. If mean_reward never
# changes, this snapshot behaves exactly like reading it each round.
true_rates = bandits.transform_values do |bandit|
  bandit.arms.to_h { |arm| [arm.id, arm.mean_reward] }
end

10_000.times do |i|
  bandits.each do |key, bandit|
    # The first 500 rounds pick an arm uniformly at random
    arm = i < 500 ? bandit.arms.sample : bandit.select_arm
    reward = rand <= true_rates[key][arm.id] ? 1 : 0
    bandit.update(arm, reward)

    arm0_counts[key] += 1 if arm.id.zero?

    # Fraction of the rounds so far in which arm 0 was selected
    arm0_probs[key] << arm0_counts[key].to_f / (i + 1)
  end

  x_values << i + 1
end
+
46
# Draw one line per algorithm showing the running fraction of rounds in which
# arm 0 was selected (arm0_probs) against the iteration number (x_values).
Gnuplot.open do |gp|
  Gnuplot::Plot.new(gp) do |plot|
    plot.title "Bandit Selection Probability"
    plot.xlabel "Iterations"
    plot.ylabel "Probability"

    # Fix the y axis to the range 0-1
    plot.yrange "[0:1]"

    # Show the grid
    plot.set "grid"

    # Define the line styles
    plot.set "style line 1 linecolor rgb '#0060ad' linewidth 2"
    plot.set "style line 2 linecolor rgb '#dd181f' linewidth 2"

    # Plot the series of each algorithm, pairing it with its colour by position
    palette = ["#0060ad", "#dd181f"]
    bandits.keys.zip(palette).each do |name, color|
      series = Gnuplot::DataSet.new([x_values, arm0_probs[name]]) do |ds|
        ds.with = "lines"
        ds.linewidth = 2
        ds.linecolor = "rgb '#{color}'"
        ds.title = name.to_s
      end
      plot.data << series
    end
  end
end
@@ -1,5 +1,5 @@
1
1
  # frozen_string_literal: true
2
2
 
3
3
  module ClassicBandit
4
- VERSION = "0.1.0"
4
+ VERSION = "0.1.1"
5
5
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: classic_bandit
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.0
4
+ version: 0.1.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Kohei Tsuyuki
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2024-12-27 00:00:00.000000000 Z
11
+ date: 2024-12-28 00:00:00.000000000 Z
12
12
  dependencies: []
13
13
  description: Implementation of classic multi-armed bandit algorithms in Ruby. Supports
14
14
  Thompson Sampling, UCB1, and Epsilon-Greedy strategies with a simple, consistent
@@ -26,6 +26,10 @@ files:
26
26
  - LICENSE.txt
27
27
  - README.md
28
28
  - Rakefile
29
+ - example/Gemfile
30
+ - example/Gemfile.lock
31
+ - example/beta_random.rb
32
+ - example/simulation.rb
29
33
  - lib/classic_bandit.rb
30
34
  - lib/classic_bandit/arm.rb
31
35
  - lib/classic_bandit/arm_updatable.rb