classic_bandit 0.1.0 → 0.1.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.rubocop.yml +8 -1
- data/CLAUDE.md +45 -0
- data/Rakefile +3 -1
- data/example/Gemfile +12 -0
- data/example/Gemfile.lock +20 -0
- data/example/beta_random.rb +87 -0
- data/example/simulation.rb +73 -0
- data/lib/classic_bandit/thompson_sampling.rb +31 -11
- data/lib/classic_bandit/version.rb +1 -1
- metadata +10 -5
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: 9e7b77e4f21e909a3515df0c5785b2f321ce3cf2977d825f0190adcd9f15a82d
|
|
4
|
+
data.tar.gz: daeb7c6386d836256e0a32c65fc72e68ad02ad2283b4a6e89b8a4eb9eebf35ab
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: 7ca36345f798cc2e869cba8e0c3e41b33892771227fcd12c85c2d9973d6fb697a79b22366f6f59a7bb041772e1d4688147162b78f0d2530ec9d4dd9da5e45981
|
|
7
|
+
data.tar.gz: 1bea576eeb3b00795b40206279376fd9b1da678cfd9ce5157829d1be8983fcb7fa7f488d91733715d263e772de250eb1a79145790152e99cb771515453db3fbe
|
data/.rubocop.yml
CHANGED
|
@@ -1,5 +1,9 @@
|
|
|
1
1
|
AllCops:
|
|
2
2
|
TargetRubyVersion: 3.0
|
|
3
|
+
NewCops: disable
|
|
4
|
+
Include:
|
|
5
|
+
- 'lib/**/*'
|
|
6
|
+
- 'spec/**/*'
|
|
3
7
|
|
|
4
8
|
Style/StringLiterals:
|
|
5
9
|
EnforcedStyle: double_quotes
|
|
@@ -13,4 +17,7 @@ Metrics/BlockLength:
|
|
|
13
17
|
|
|
14
18
|
Style/Documentation:
|
|
15
19
|
Exclude:
|
|
16
|
-
- '**/*'
|
|
20
|
+
- '**/*'
|
|
21
|
+
|
|
22
|
+
Metrics/MethodLength:
|
|
23
|
+
Max: 30
|
data/CLAUDE.md
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
# CLAUDE.md
|
|
2
|
+
|
|
3
|
+
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
|
4
|
+
|
|
5
|
+
## Development Commands
|
|
6
|
+
|
|
7
|
+
- **Run tests**: `bundle exec rspec`
|
|
8
|
+
- **Lint code**: `bundle exec rubocop`
|
|
9
|
+
- **Run all checks**: `bundle exec rake` (runs both tests and linting)
|
|
10
|
+
- **Install dependencies**: `bundle install`
|
|
11
|
+
|
|
12
|
+
The project uses RuboCop for linting; the Rakefile references the RuboCop config file explicitly (line 15).
|
|
13
|
+
|
|
14
|
+
## Architecture Overview
|
|
15
|
+
|
|
16
|
+
ClassicBandit is a Ruby gem implementing multi-armed bandit algorithms for A/B testing and optimization. The library uses Zeitwerk for autoloading and follows a modular design pattern.
|
|
17
|
+
|
|
18
|
+
### Core Components
|
|
19
|
+
|
|
20
|
+
**Arm (`lib/classic_bandit/arm.rb`)**: Represents a bandit arm with trial/success tracking and mean reward calculation.
|
|
21
|
+
|
|
22
|
+
**ArmUpdatable (`lib/classic_bandit/arm_updatable.rb`)**: Shared module providing `update(arm, reward)` method for all bandit algorithms. Validates rewards are 0 or 1.
|
|
23
|
+
|
|
24
|
+
**Algorithm Implementations**:
|
|
25
|
+
- **EpsilonGreedy**: Simple ε-greedy with exploration/exploitation balance
|
|
26
|
+
- **UCB1**: Upper Confidence Bound without explicit parameters
|
|
27
|
+
- **Softmax**: Temperature-based Boltzmann distribution selection
|
|
28
|
+
- **ThompsonSampling**: Bayesian approach with Beta-Bernoulli model using custom Gamma random number generation
|
|
29
|
+
|
|
30
|
+
### Key Design Patterns
|
|
31
|
+
|
|
32
|
+
All algorithms implement the same interface:
|
|
33
|
+
- `select_arm()` → returns an Arm instance
|
|
34
|
+
- `update(arm, reward)` → updates arm statistics (from ArmUpdatable)
|
|
35
|
+
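- Handle untested arms by random selection (ThompsonSampling does this implicitly: an arm with no trials is scored by a draw from its Beta prior, which is uniform with the default priors)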
|
|
36
|
+
|
|
37
|
+
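The Thompson Sampling implementation includes custom statistical functions (`gamma_random`, `normal_random`) and samples from the Gamma distribution with the Marsaglia–Tsang method. Its Beta prior is configurable through the `alpha_prior` and `beta_prior` constructor arguments (both default to 1.0 and must be positive).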
|
|
38
|
+
|
|
39
|
+
### Testing Structure
|
|
40
|
+
|
|
41
|
+
Tests are organized in `spec/` with individual algorithm specs and a main gem spec. Uses standard RSpec testing framework.
|
|
42
|
+
|
|
43
|
+
### Example Usage
|
|
44
|
+
|
|
45
|
+
The `example/` directory contains Gnuplot-based scripts: `simulation.rb` compares UCB1 and Thompson Sampling on arms pre-populated with trial/success counts, and `beta_random.rb` checks the Gamma-ratio Beta sampler against the theoretical Beta PDF.
|
data/Rakefile
CHANGED
data/example/Gemfile.lock
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
GEM
|
|
2
|
+
remote: https://rubygems.org/
|
|
3
|
+
specs:
|
|
4
|
+
classic_bandit (0.1.0)
|
|
5
|
+
gnuplot (2.6.2)
|
|
6
|
+
matrix (0.4.2)
|
|
7
|
+
zeitwerk (2.7.1)
|
|
8
|
+
|
|
9
|
+
PLATFORMS
|
|
10
|
+
ruby
|
|
11
|
+
x86_64-linux
|
|
12
|
+
|
|
13
|
+
DEPENDENCIES
|
|
14
|
+
classic_bandit (~> 0.1.0)
|
|
15
|
+
gnuplot (~> 2.6)
|
|
16
|
+
matrix (~> 0.4.2)
|
|
17
|
+
zeitwerk (~> 2.7)
|
|
18
|
+
|
|
19
|
+
BUNDLED WITH
|
|
20
|
+
2.6.2
|
|
data/example/beta_random.rb
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "gnuplot"
|
|
4
|
+
|
|
5
|
+
# Draws a single Gamma(alpha, 1) variate with Marsaglia and Tsang's
# squeeze-and-reject method.
#
# For alpha < 1 the method is not applied directly: a draw from
# Gamma(alpha + 1) is scaled by U**(1 / alpha), with U uniform on [0, 1).
# Standard-normal proposals come from the top-level `normal_random`.
def gamma_random(alpha) # rubocop:disable Metrics/AbcSize
  if alpha < 1
    boosted = gamma_random(alpha + 1)
    return boosted * rand**(1.0 / alpha)
  end

  offset = alpha - 1.0 / 3
  scale = 1.0 / Math.sqrt(9 * offset)

  loop do
    z = normal_random
    cube = (1 + scale * z)**3
    # A non-positive cube falls outside the proposal's valid range; redraw.
    next if cube <= 0

    uniform = rand
    # Cheap squeeze test first; the log-based full test only runs if it fails.
    squeeze_ok = uniform < 1 - 0.0331 * z**4
    return offset * cube if squeeze_ok || Math.log(uniform) < 0.5 * z * z + offset - offset * cube + offset * Math.log(cube)
  end
end
|
|
26
|
+
|
|
27
|
+
# Draws one standard-normal variate using the Box–Muller transform.
#
# `rand` returns a float in [0, 1), so `1.0 - rand` lies in (0, 1] and its
# logarithm is always finite. Taking `Math.log(rand)` directly produced
# -Infinity, and therefore an infinite sample, whenever `rand` returned 0.0.
def normal_random
  radius = Math.sqrt(-2 * Math.log(1.0 - rand))
  theta = 2 * Math::PI * rand
  radius * Math.cos(theta)
end
|
|
32
|
+
|
|
33
|
+
# Euler Beta function B(alpha, beta) = Γ(alpha) Γ(beta) / Γ(alpha + beta).
# Evaluated through log-gamma so the intermediate Γ values cannot overflow;
# only the magnitude from Math.lgamma is used (its sign element is ignored).
def beta_function(alpha, beta)
  log_gamma = ->(value) { Math.lgamma(value).first }
  Math.exp(log_gamma.call(alpha) + log_gamma.call(beta) - log_gamma.call(alpha + beta))
end
|
|
39
|
+
|
|
40
|
+
# Probability density of Beta(alpha, beta) evaluated at x.
#
# Returns 0 for x outside the open interval (0, 1). The density is computed
# in log space as
#   (alpha - 1) ln x + (beta - 1) ln(1 - x) - ln B(alpha, beta)
# because the direct product x**(alpha - 1) * (1 - x)**(beta - 1) and
# B(alpha, beta) both underflow to 0.0 for large parameters, which turned
# the quotient into NaN (e.g. Beta(600, 600) at x = 0.5).
def beta_pdf(x, alpha, beta)
  return 0 if x <= 0 || x >= 1

  log_beta = Math.lgamma(alpha)[0] + Math.lgamma(beta)[0] - Math.lgamma(alpha + beta)[0]
  Math.exp((alpha - 1) * Math.log(x) + (beta - 1) * Math.log(1 - x) - log_beta)
end
|
|
45
|
+
|
|
46
|
+
# Empirical check of the Gamma-ratio Beta sampler: if X ~ Gamma(a) and
# Y ~ Gamma(b) are independent, X / (X + Y) ~ Beta(a, b). Draw samples that
# way, then overlay their normalized histogram on the analytic Beta PDF.
alpha = 41
beta = 61
sample_count = 10_000

data = Array.new(sample_count) do
  x1 = gamma_random(alpha)
  x2 = gamma_random(beta)
  x1 / (x1 + x2)
end

Gnuplot.open do |gp|
  Gnuplot::Plot.new(gp) do |plot|
    plot.title "Beta distribution histogram"
    plot.xlabel "Value"
    # Bars are normalized to unit area, i.e. a density comparable with the PDF.
    plot.ylabel "Density"

    min_val = 0.0
    max_val = 1.0
    bin_count = 60
    bin_width = (max_val - min_val) / bin_count

    plot.xrange "[0:1]"
    total_count = data.length.to_f

    plot.set "style data histograms"
    plot.set "style fill solid 0.5"

    # Count by integer bin index instead of float hash keys; clamp so a sample
    # equal to max_val lands in the last bin rather than creating a new one.
    counts = Array.new(bin_count, 0)
    data.each do |v|
      index = ((v - min_val) / bin_width).floor.clamp(0, bin_count - 1)
      counts[index] += 1
    end

    # Gnuplot's `boxes` style centers each box on its x value, so plot bin
    # centers (not left edges) to keep the bars aligned with the PDF curve.
    bin_centers = Array.new(bin_count) { |i| min_val + (i + 0.5) * bin_width }
    densities = counts.map { |c| c / (total_count * bin_width) }

    plot.data << Gnuplot::DataSet.new([bin_centers, densities]) do |ds|
      ds.with = "boxes"
      ds.title = "Empirical"
    end

    x_points = (0..100).map { |i| i / 100.0 }
    y_points = x_points.map { |x| beta_pdf(x, alpha, beta) }
    plot.data << Gnuplot::DataSet.new([x_points, y_points]) do |ds|
      ds.with = "lines"
      ds.linewidth = 2
      ds.title = "Theoretical PDF"
    end
  end
end
|
|
data/example/simulation.rb
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require 'classic_bandit'
|
|
4
|
+
require 'gnuplot'
|
|
5
|
+
|
|
6
|
+
# Simulation setup: every arm starts with 1000 pre-recorded trials, and the
# seeded success ratio of each arm is also its fixed, true reward probability.
TRIALS = 1000
INITIAL_SUCCESSES = { 0 => 120, 1 => 110, 2 => 100 }.freeze
TRUE_RATES = INITIAL_SUCCESSES.transform_values { |s| s.to_f / TRIALS }.freeze
ITERATIONS = 10_000
WARMUP_ITERATIONS = 500

# Builds a fresh set of arms. Each bandit gets its own Arm objects because
# `update` records results on the arm it is given.
build_arms = lambda do
  INITIAL_SUCCESSES.map do |id, successes|
    ClassicBandit::Arm.new(id: id, trials: TRIALS, successes: successes)
  end
end

bandits = {
  "UCB1" => ClassicBandit::Ucb1.new(arms: build_arms.call),
  "Thompson sampling" => ClassicBandit::ThompsonSampling.new(arms: build_arms.call)
}

arm0_counts = Hash.new(0)
arm0_probs = bandits.keys.to_h { |key| [key, []] }
x_values = []

ITERATIONS.times do |i|
  bandits.each do |key, bandit|
    # Warm-up: choose arms uniformly at random for the first iterations.
    arm = i < WARMUP_ITERATIONS ? bandit.arms.sample : bandit.select_arm

    # Draw the reward from the arm's fixed true rate. Sampling from the arm's
    # running estimate instead would let the target drift with every update,
    # so the simulation would reinforce its own choices. `rand` is in [0, 1),
    # so `rand < p` succeeds with probability exactly p.
    reward = rand < TRUE_RATES.fetch(arm.id) ? 1 : 0
    bandit.update(arm, reward)

    arm0_counts[key] += 1 if arm.id.zero?
    # Cumulative fraction of iterations so far that selected arm 0.
    arm0_probs[key] << arm0_counts[key].to_f / (i + 1)
  end

  x_values << i + 1
end

Gnuplot.open do |gp|
  Gnuplot::Plot.new(gp) do |plot|
    plot.title "Bandit Selection Probability"
    plot.xlabel "Iterations"
    plot.ylabel "Probability"

    # Fix the y-axis to the probability range 0-1.
    plot.yrange "[0:1]"

    # Show grid lines.
    plot.set "grid"

    # Define line styles.
    plot.set "style line 1 linecolor rgb '#0060ad' linewidth 2"
    plot.set "style line 2 linecolor rgb '#dd181f' linewidth 2"

    # Plot each algorithm's cumulative arm-0 selection rate.
    colors = ["#0060ad", "#dd181f"]
    bandits.each_key.with_index do |key, index|
      plot.data << Gnuplot::DataSet.new([x_values, arm0_probs[key]]) do |ds|
        ds.with = "lines"
        ds.linewidth = 2
        # Cycle through the palette so extra bandits never get an empty color.
        ds.linecolor = "rgb '#{colors[index % colors.size]}'"
        ds.title = key.to_s
      end
    end
  end
end
|
|
data/lib/classic_bandit/thompson_sampling.rb
CHANGED
|
@@ -14,26 +14,39 @@ module ClassicBandit
|
|
|
14
14
|
class ThompsonSampling
|
|
15
15
|
include ArmUpdatable
|
|
16
16
|
|
|
17
|
-
attr_reader :arms
|
|
17
|
+
attr_reader :arms, :alpha_prior, :beta_prior
|
|
18
|
+
|
|
19
|
+
# @param arms [Array<Arm>] Array of arms to choose from
|
|
20
|
+
# @param alpha_prior [Float] Prior parameter for successes (default: 1.0)
|
|
21
|
+
# @param beta_prior [Float] Prior parameter for failures (default: 1.0)
|
|
22
|
+
def initialize(arms:, alpha_prior: 1.0, beta_prior: 1.0)
|
|
23
|
+
raise ArgumentError, "alpha_prior must be positive" unless alpha_prior.positive?
|
|
24
|
+
raise ArgumentError, "beta_prior must be positive" unless beta_prior.positive?
|
|
18
25
|
|
|
19
|
-
def initialize(arms:)
|
|
20
26
|
@arms = arms
|
|
27
|
+
@alpha_prior = alpha_prior
|
|
28
|
+
@beta_prior = beta_prior
|
|
21
29
|
end
|
|
22
30
|
|
|
23
31
|
def select_arm
|
|
24
|
-
return @arms.sample if @arms.all? { |arm| arm.trials.zero? }
|
|
25
|
-
|
|
26
32
|
@arms.max_by { |arm| ts_score(arm) }
|
|
27
33
|
end
|
|
28
34
|
|
|
29
35
|
private
|
|
30
36
|
|
|
31
37
|
def ts_score(arm)
|
|
32
|
-
|
|
33
|
-
|
|
38
|
+
alpha = arm.successes + @alpha_prior
|
|
39
|
+
beta = (arm.trials - arm.successes) + @beta_prior
|
|
40
|
+
|
|
41
|
+
beta_sample(alpha, beta)
|
|
42
|
+
end
|
|
34
43
|
|
|
35
|
-
|
|
36
|
-
|
|
44
|
+
def beta_sample(alpha, beta)
|
|
45
|
+
# Beta(1,1) = Uniform(0,1)
|
|
46
|
+
return rand if alpha == 1.0 && beta == 1.0 # rubocop:disable Lint/FloatComparison
|
|
47
|
+
|
|
48
|
+
x = gamma_random(alpha)
|
|
49
|
+
y = gamma_random(beta)
|
|
37
50
|
x / (x + y)
|
|
38
51
|
end
|
|
39
52
|
|
|
@@ -45,11 +58,18 @@ module ClassicBandit
|
|
|
45
58
|
c = 1.0 / Math.sqrt(9 * d)
|
|
46
59
|
|
|
47
60
|
loop do
|
|
48
|
-
|
|
49
|
-
v = (1 + c *
|
|
61
|
+
x = normal_random
|
|
62
|
+
v = (1 + c * x)**3
|
|
63
|
+
|
|
64
|
+
next if v <= 0
|
|
65
|
+
|
|
50
66
|
u = rand
|
|
51
67
|
|
|
52
|
-
|
|
68
|
+
# Squeeze test
|
|
69
|
+
return d * v if u < 1 - 0.0331 * x**4
|
|
70
|
+
|
|
71
|
+
# Full test
|
|
72
|
+
return d * v if Math.log(u) < 0.5 * x * x + d - d * v + d * Math.log(v)
|
|
53
73
|
end
|
|
54
74
|
end
|
|
55
75
|
|
metadata
CHANGED
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
|
2
2
|
name: classic_bandit
|
|
3
3
|
version: !ruby/object:Gem::Version
|
|
4
|
-
version: 0.1.0
|
|
4
|
+
version: 0.1.2
|
|
5
5
|
platform: ruby
|
|
6
6
|
authors:
|
|
7
7
|
- Kohei Tsuyuki
|
|
8
|
-
autorequire:
|
|
8
|
+
autorequire:
|
|
9
9
|
bindir: exe
|
|
10
10
|
cert_chain: []
|
|
11
|
-
date:
|
|
11
|
+
date: 2025-11-09 00:00:00.000000000 Z
|
|
12
12
|
dependencies: []
|
|
13
13
|
description: Implementation of classic multi-armed bandit algorithms in Ruby. Supports
|
|
14
14
|
Thompson Sampling, UCB1, and Epsilon-Greedy strategies with a simple, consistent
|
|
@@ -23,9 +23,14 @@ files:
|
|
|
23
23
|
- ".rspec"
|
|
24
24
|
- ".rubocop.yml"
|
|
25
25
|
- CHANGELOG.md
|
|
26
|
+
- CLAUDE.md
|
|
26
27
|
- LICENSE.txt
|
|
27
28
|
- README.md
|
|
28
29
|
- Rakefile
|
|
30
|
+
- example/Gemfile
|
|
31
|
+
- example/Gemfile.lock
|
|
32
|
+
- example/beta_random.rb
|
|
33
|
+
- example/simulation.rb
|
|
29
34
|
- lib/classic_bandit.rb
|
|
30
35
|
- lib/classic_bandit/arm.rb
|
|
31
36
|
- lib/classic_bandit/arm_updatable.rb
|
|
@@ -42,7 +47,7 @@ metadata:
|
|
|
42
47
|
homepage_uri: https://github.com/t-chov/classic_bandit
|
|
43
48
|
source_code_uri: https://github.com/t-chov/classic_bandit
|
|
44
49
|
changelog_uri: https://github.com/t-chov/classic_bandit/blob/main/CHANGELOG.md
|
|
45
|
-
post_install_message:
|
|
50
|
+
post_install_message:
|
|
46
51
|
rdoc_options: []
|
|
47
52
|
require_paths:
|
|
48
53
|
- lib
|
|
@@ -58,7 +63,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
|
58
63
|
version: '0'
|
|
59
64
|
requirements: []
|
|
60
65
|
rubygems_version: 3.2.33
|
|
61
|
-
signing_key:
|
|
66
|
+
signing_key:
|
|
62
67
|
specification_version: 4
|
|
63
68
|
summary: A Ruby library for classic (non-contextual) multi-armed bandit algorithms
|
|
64
69
|
test_files: []
|