ruby-em_algorithm 0.0.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data/Gemfile +6 -0
- data/Gemfile.lock +30 -0
- data/README.md +44 -0
- data/Rakefile +7 -0
- data/example/.ex1.rb.swp +0 -0
- data/example/.ex2.rb.swp +0 -0
- data/example/.ex3-tmp.rb.swp +0 -0
- data/example/.ex3.rb.swp +0 -0
- data/example/data/2dim-gmm-new.txt +1267 -0
- data/example/data/2dim-gmm-simple.txt +676 -0
- data/example/data/2dim-gmm-test.txt +6565 -0
- data/example/data/2dim-gmm-test2.txt +2782 -0
- data/example/data/2dim-gmm-test3.csv +1641 -0
- data/example/data/2dim-gmm-test3.txt +2782 -0
- data/example/data/2dim-gmm-test4.csv +868 -0
- data/example/data/2dim-gmm-test4.txt +4924 -0
- data/example/data/2dim-gmm-without_weight-small.txt +2401 -0
- data/example/data/2dim-gmm-without_weight.txt +18001 -0
- data/example/data/2dim-gmm.txt +1267 -0
- data/example/data/gmm-new.txt +10001 -0
- data/example/data/gmm-simple.txt +676 -0
- data/example/data/gmm.txt +10001 -0
- data/example/data/old-gmm.txt +10000 -0
- data/example/ex1.rb +20 -0
- data/example/ex1.rb~ +20 -0
- data/example/ex2.rb +33 -0
- data/example/ex2.rb~ +33 -0
- data/example/ex3-tmp.rb +23 -0
- data/example/ex3-tmp.rb~ +25 -0
- data/example/ex3.rb +43 -0
- data/example/ex3.rb~ +43 -0
- data/example/tools/.2dim.rb.swp +0 -0
- data/example/tools/2dim.rb +69 -0
- data/example/tools/2dim.rb~ +69 -0
- data/example/tools/boxmuller.rb +28 -0
- data/example/tools/boxmuller.rb~ +28 -0
- data/example/tools/conv_from_yaml.rb +8 -0
- data/example/tools/conv_from_yaml_to_csv.rb +8 -0
- data/example/tools/conv_to_yaml.rb +17 -0
- data/example/tools/ellipsoid.gnuplot +63 -0
- data/example/tools/ellipsoid.gnuplot~ +64 -0
- data/example/tools/histogram.rb +19 -0
- data/example/tools/histogram2d.rb +20 -0
- data/example/tools/histogram2d.rb~ +18 -0
- data/example/tools/kmeans.rb +34 -0
- data/example/tools/mean.rb +19 -0
- data/example/tools/table.data +4618 -0
- data/example/tools/tmp.txt +69632 -0
- data/example/tools/xmeans.R +608 -0
- data/example/tools/xmeans.rb +35 -0
- data/lib/em_algorithm/.base.rb.swp +0 -0
- data/lib/em_algorithm/base.rb +116 -0
- data/lib/em_algorithm/base.rb~ +116 -0
- data/lib/em_algorithm/convergence/.chi_square.rb.swp +0 -0
- data/lib/em_algorithm/convergence/.likelihood.rb.swp +0 -0
- data/lib/em_algorithm/convergence/check_method.rb +4 -0
- data/lib/em_algorithm/convergence/check_method.rb~ +0 -0
- data/lib/em_algorithm/convergence/chi_square.rb +40 -0
- data/lib/em_algorithm/convergence/chi_square.rb~ +40 -0
- data/lib/em_algorithm/convergence/likelihood.rb +35 -0
- data/lib/em_algorithm/convergence/likelihood.rb~ +35 -0
- data/lib/em_algorithm/models/.gaussian.rb.swp +0 -0
- data/lib/em_algorithm/models/.md_gaussian.rb.swp +0 -0
- data/lib/em_algorithm/models/.mixture.rb.swp +0 -0
- data/lib/em_algorithm/models/.model.rb.swp +0 -0
- data/lib/em_algorithm/models/gaussian.rb +47 -0
- data/lib/em_algorithm/models/gaussian.rb~ +47 -0
- data/lib/em_algorithm/models/md_gaussian.rb +67 -0
- data/lib/em_algorithm/models/md_gaussian.rb~ +67 -0
- data/lib/em_algorithm/models/mixture.rb +122 -0
- data/lib/em_algorithm/models/mixture.rb~ +122 -0
- data/lib/em_algorithm/models/model.rb +19 -0
- data/lib/em_algorithm/models/model.rb~ +19 -0
- data/lib/ruby-em_algorithm.rb +3 -0
- data/lib/ruby-em_algorithm/version.rb +3 -0
- data/ruby-em_algorithm.gemspec +21 -0
- data/spec/spec_helper.rb +9 -0
- metadata +178 -0
@@ -0,0 +1,47 @@
|
|
1
|
+
module EMAlgorithm
  # Univariate normal distribution N(mu, sigma^2), used as a component of a
  # Mixture.  +sigma+ is the standard deviation, not the variance.
  #
  # NOTE(review): `exp`, `sqrt`, `PI` and `DIGIT` are not defined in this file;
  # presumably they are provided by Model (e.g. via `include Math`) — confirm.
  class Gaussian < Model
    # NOTE(review): `dim` is exposed but never assigned in this class.
    attr_accessor :mu, :sigma, :dim

    # @param mu [Float] mean
    # @param sigma [Float] standard deviation
    def initialize(mu = 0.0, sigma = 1.0)
      @mu = mu
      @sigma = sigma
    end

    # Normal density at +x+:
    #   exp(-(x - mu)^2 / (2 sigma^2)) / (sqrt(2 pi) * sigma)
    #
    # @param x [Float]
    # @return [Float]
    def probability_density_function(x)
      deviation = x - @mu
      exponent = -(deviation**2.0) / (2.0 * @sigma**2)
      exp(exponent) / (sqrt(2.0 * PI) * @sigma)
    end

    # Density of an observation scaled by that observation's weight.
    #
    # @param x_with_weight [#[]] element 0 is the observation, element 1 its weight
    # @return [Float]
    def probability_density_function_with_observation_weight(x_with_weight)
      value = x_with_weight[0]
      weight = x_with_weight[1]
      weight * probability_density_function(value)
    end

    # M-step for the mean: the responsibility-weighted average of the data.
    #
    # @param data_array [Array<Float>] observations
    # @param temp_weight [Float] total responsibility of this component
    #   (Mixture passes the sum of +temp_weight_per_datum+)
    # @param temp_weight_per_datum [Array<Float>] this component's
    #   responsibility for each observation, aligned with +data_array+
    # @return [Float] the new mean
    def update_average!(data_array, temp_weight, temp_weight_per_datum)
      weighted_sum = 0.0
      data_array.size.times do |i|
        weighted_sum += temp_weight_per_datum[i] * data_array[i]
      end
      @mu = weighted_sum / temp_weight
    end

    # M-step for the standard deviation: square root of the
    # responsibility-weighted mean squared deviation from the current +mu+.
    #
    # @param (see #update_average!)
    # @return [Float] the new standard deviation
    def update_sigma!(data_array, temp_weight, temp_weight_per_datum)
      weighted_squares = 0.0
      data_array.size.times do |i|
        weighted_squares += temp_weight_per_datum[i] * (data_array[i] - @mu) ** 2
      end
      @sigma = sqrt(weighted_squares / temp_weight)
    end

    # Runs both M-step updates.  The mean is updated first because the
    # standard deviation is measured around the freshly updated mean.
    #
    # @param (see #update_average!)
    # @return [Float] the new standard deviation
    def update_parameters!(data_array, temp_weight, temp_weight_per_datum)
      update_average!(data_array, temp_weight, temp_weight_per_datum)
      update_sigma!(data_array, temp_weight, temp_weight_per_datum)
    end

    # gnuplot expression (in variable x) for this density, with parameters
    # rounded to DIGIT places.
    #
    # @return [String]
    def to_gnuplot
      mean = @mu.round(DIGIT)
      variance = (@sigma ** 2).round(DIGIT)
      stddev = @sigma.round(DIGIT)
      "exp(-((x-(#{mean}))**2.0)/(2.0*#{variance}))/(sqrt(2.0*pi)*#{stddev})"
    end

    # gnuplot plot clause: the density expression drawn with lines on the
    # x1y2 axes, titled "weight*N(mu,sigma)".
    #
    # @param weight [Float] mixing weight shown in the title
    # @return [String]
    def to_gnuplot_with_title(weight)
      label = "#{weight.round(DIGIT)}*N(#{@mu.round(DIGIT)},#{@sigma.round(DIGIT)})"
      "#{to_gnuplot} w l axis x1y2 lw 3 title '#{label}'"
    end
  end
end
|
@@ -0,0 +1,47 @@
|
|
1
|
+
module EMAlgorithm
  # Univariate normal distribution N(mu, sigma^2), used as a component of a
  # Mixture.  +sigma+ is the standard deviation, not the variance.
  #
  # NOTE(review): `exp`, `sqrt`, `PI` and `DIGIT` are not defined in this file;
  # presumably they are provided by Model (e.g. via `include Math`) — confirm.
  class Gaussian < Model
    # NOTE(review): `dim` is exposed but never assigned in this class.
    attr_accessor :mu, :sigma, :dim

    # @param mu [Float] mean
    # @param sigma [Float] standard deviation
    def initialize(mu = 0.0, sigma = 1.0)
      @mu = mu
      @sigma = sigma
    end

    # Normal density at +x+:
    #   exp(-(x - mu)^2 / (2 sigma^2)) / (sqrt(2 pi) * sigma)
    #
    # @param x [Float]
    # @return [Float]
    def probability_density_function(x)
      deviation = x - @mu
      exponent = -(deviation**2.0) / (2.0 * @sigma**2)
      exp(exponent) / (sqrt(2.0 * PI) * @sigma)
    end

    # Density of an observation scaled by that observation's weight.
    #
    # @param x_with_weight [#[]] element 0 is the observation, element 1 its weight
    # @return [Float]
    def probability_density_function_with_observation_weight(x_with_weight)
      value = x_with_weight[0]
      weight = x_with_weight[1]
      weight * probability_density_function(value)
    end

    # M-step for the mean: the responsibility-weighted average of the data.
    #
    # @param data_array [Array<Float>] observations
    # @param temp_weight [Float] total responsibility of this component
    #   (Mixture passes the sum of +temp_weight_per_datum+)
    # @param temp_weight_per_datum [Array<Float>] this component's
    #   responsibility for each observation, aligned with +data_array+
    # @return [Float] the new mean
    def update_average!(data_array, temp_weight, temp_weight_per_datum)
      weighted_sum = 0.0
      data_array.size.times do |i|
        weighted_sum += temp_weight_per_datum[i] * data_array[i]
      end
      @mu = weighted_sum / temp_weight
    end

    # M-step for the standard deviation: square root of the
    # responsibility-weighted mean squared deviation from the current +mu+.
    #
    # @param (see #update_average!)
    # @return [Float] the new standard deviation
    def update_sigma!(data_array, temp_weight, temp_weight_per_datum)
      weighted_squares = 0.0
      data_array.size.times do |i|
        weighted_squares += temp_weight_per_datum[i] * (data_array[i] - @mu) ** 2
      end
      @sigma = sqrt(weighted_squares / temp_weight)
    end

    # Runs both M-step updates.  The mean is updated first because the
    # standard deviation is measured around the freshly updated mean.
    #
    # @param (see #update_average!)
    # @return [Float] the new standard deviation
    def update_parameters!(data_array, temp_weight, temp_weight_per_datum)
      update_average!(data_array, temp_weight, temp_weight_per_datum)
      update_sigma!(data_array, temp_weight, temp_weight_per_datum)
    end

    # gnuplot expression (in variable x) for this density, with parameters
    # rounded to DIGIT places.
    #
    # @return [String]
    def to_gnuplot
      mean = @mu.round(DIGIT)
      variance = (@sigma ** 2).round(DIGIT)
      stddev = @sigma.round(DIGIT)
      "exp(-((x-(#{mean}))**2.0)/(2.0*#{variance}))/(sqrt(2.0*pi)*#{stddev})"
    end

    # gnuplot plot clause: the density expression drawn with lines on the
    # x1y2 axes, titled "weight*N(mu,sigma)".
    #
    # @param weight [Float] mixing weight shown in the title
    # @return [String]
    def to_gnuplot_with_title(weight)
      label = "#{weight.round(DIGIT)}*N(#{@mu.round(DIGIT)},#{@sigma.round(DIGIT)})"
      "#{to_gnuplot} w l axis x1y2 lw 3 title '#{label}'"
    end
  end
end
|
@@ -0,0 +1,67 @@
|
|
1
|
+
module EMAlgorithm
  # Multivariate normal distribution N(mu, Sigma) built on Ruby/GSL.
  # +mu+ is a GSL::Vector (used as a row vector) and +sigma2+ is the
  # covariance GSL::Matrix.  sqrt(det Sigma) and Sigma^-1 are cached because
  # every density evaluation needs them.
  #
  # NOTE(review): `exp`, `sqrt`, `PI` and `DIGIT` are not defined in this file;
  # presumably they are provided by Model — confirm.
  class MdGaussian < Model
    attr_accessor :mu
    attr_reader :sigma2

    # @param mu [GSL::Vector] mean vector
    # @param sigma2 [GSL::Matrix] covariance matrix, square with side mu.size
    # @raise [ArgumentError] if mu is not exactly a GSL::Vector, sigma2 is not
    #   exactly a GSL::Matrix, or the sizes disagree
    def initialize(mu = GSL::Vector[0.0, 0.0], sigma2 = GSL::Matrix[[1.0, 0.0], [0.0, 1.0]])
      # check mu
      if mu.class != GSL::Vector
        raise ArgumentError, "mu should be GSL::Vector."
      end
      @mu = mu
      # check sigma2
      if sigma2.class != GSL::Matrix
        raise ArgumentError, "sigma2 should be GSL::Matrix."
      elsif sigma2.size1 != @mu.size || sigma2.size2 != @mu.size
        raise ArgumentError, "The size of sigma2 matrix does not match with mu vector."
      end
      @sigma2 = sigma2
      refresh_covariance_cache!
    end

    # Replaces the covariance matrix and recomputes the cached square root of
    # its determinant and its inverse.  A plain attr_accessor writer left
    # those caches stale, so the density kept using the old covariance.
    #
    # @param matrix [GSL::Matrix]
    def sigma2=(matrix)
      @sigma2 = matrix
      refresh_covariance_cache!
    end

    # Density at +x+:
    #   exp(-(x - mu) Sigma^-1 (x - mu)^T / 2) / ((sqrt(2 pi))^d * sqrt(det Sigma))
    #
    # @param x [GSL::Vector] row vector of the same size as mu
    # @return the density value (a scalar)
    def probability_density_function(x)
      diff = x - @mu
      exp(-(diff * @sigma2_invert * diff.trans)/2.0)/((sqrt(2.0*PI)**@mu.size)*@sqrt_sigma2_det)
    end

    # M-step for the mean: the responsibility-weighted average of the data.
    #
    # @param data_array [Array<GSL::Vector>] observations
    # @param temp_weight [Float] total responsibility of this component
    # @param temp_weight_per_datum [Array<Float>] this component's
    #   responsibility for each observation, aligned with +data_array+
    # @return [GSL::Vector] the new mean
    def update_average!(data_array, temp_weight, temp_weight_per_datum)
      data_sum = (0...data_array.size).inject(GSL::Vector.alloc(@mu.size).set_zero) do |sum, di|
        sum + temp_weight_per_datum[di] * data_array[di]
      end
      @mu = data_sum / temp_weight
    end

    # M-step for the covariance: the responsibility-weighted mean of the outer
    # products (x - mu)^T (x - mu) around the current +mu+.  The caches are
    # refreshed afterwards.
    #
    # @param (see #update_average!)
    # @return [GSL::Matrix] the new inverse covariance
    def update_sigma2!(data_array, temp_weight, temp_weight_per_datum)
      dim = @mu.size
      # Seed with an explicit zero matrix (as update_average! does with a zero
      # vector) instead of Float 0.0, which relied on GSL coercion.
      data_sum = (0...data_array.size).inject(GSL::Matrix.alloc(dim, dim).set_zero) do |sum, di|
        deviation = data_array[di] - @mu
        sum + temp_weight_per_datum[di] * deviation.trans * deviation
      end
      @sigma2 = data_sum / temp_weight
      refresh_covariance_cache!
    end

    # Runs both M-step updates.  The mean is updated first because the
    # covariance is measured around the freshly updated mean.
    #
    # @param (see #update_average!)
    def update_parameters!(data_array, temp_weight, temp_weight_per_datum)
      update_average!(data_array, temp_weight, temp_weight_per_datum)
      update_sigma2!(data_array, temp_weight, temp_weight_per_datum)
    end

    # gnuplot expression (in variables x, y) for a 2-D density.  For other
    # dimensions it returns a plain "N(mu, sigma2)" description instead.
    #
    # @return [String]
    def to_gnuplot
      if @mu.size == 2
        # With d = [x - mu_x, y - mu_y] and P = Sigma^-1 = [[p_x, p_xy], [p_yx, p_y]]:
        #   d P d^T = p_x*(x - mu_x)**2 + (p_xy + p_yx)*(x - mu_x)*(y - mu_y) + p_y*(y - mu_y)**2
        # The quadratic form must use the *inverse* covariance, matching
        # probability_density_function.  The covariance entries themselves
        # give the wrong surface unless Sigma is the identity.
        precision = @sigma2_invert
        cross = precision[0,1] + precision[1,0]
        xy = "+(#{cross.round(DIGIT)})*(x-(#{@mu[0].round(DIGIT)}))*(y-(#{@mu[1].round(DIGIT)}))" if cross > 0 || cross < 0
        "exp(-((#{precision[0,0].round(DIGIT)})*(x-(#{@mu[0].round(DIGIT)}))**2.0+(#{precision[1,1].round(DIGIT)})*(y-(#{@mu[1].round(DIGIT)}))**2.0#{xy})/2.0)/((sqrt(2.0*pi))**#{@mu.size}*(#{@sqrt_sigma2_det.round(DIGIT)}))"
      else
        "N(#{@mu.to_a.inspect}, #{@sigma2.to_a.inspect})"
      end
    end

    # gnuplot plot clause for a 2-D density titled "weight*N(mu,sigma2)".
    # For other dimensions it returns a plain description with no title.
    #
    # @param weight [Float] mixing weight shown in the title
    # @return [String]
    def to_gnuplot_with_title(weight)
      if @mu.size == 2
        to_gnuplot + " w l lw 3 title '#{weight.round(DIGIT)}*N(#{@mu.map{|mu| mu.round(DIGIT)}.to_a.inspect},#{@sigma2.map{|sigma2| sigma2.round(DIGIT)}.to_a.inspect})'"
      else
        "N(#{@mu.to_a.inspect}, #{@sigma2.to_a.inspect})"
      end
    end

    private

    # Recomputes the values derived from @sigma2 that the density uses.
    # Returns the new inverse covariance.
    def refresh_covariance_cache!
      @sqrt_sigma2_det = sqrt(@sigma2.det)
      @sigma2_invert = @sigma2.invert
    end
  end
end
|
@@ -0,0 +1,67 @@
|
|
1
|
+
module EMAlgorithm
  # Multivariate normal distribution N(mu, Sigma) built on Ruby/GSL.
  # +mu+ is a GSL::Vector (used as a row vector) and +sigma2+ is the
  # covariance GSL::Matrix.  sqrt(det Sigma) and Sigma^-1 are cached because
  # every density evaluation needs them.
  #
  # NOTE(review): `exp`, `sqrt`, `PI` and `DIGIT` are not defined in this file;
  # presumably they are provided by Model — confirm.
  class MdGaussian < Model
    attr_accessor :mu
    attr_reader :sigma2

    # @param mu [GSL::Vector] mean vector
    # @param sigma2 [GSL::Matrix] covariance matrix, square with side mu.size
    # @raise [ArgumentError] if mu is not exactly a GSL::Vector, sigma2 is not
    #   exactly a GSL::Matrix, or the sizes disagree
    def initialize(mu = GSL::Vector[0.0, 0.0], sigma2 = GSL::Matrix[[1.0, 0.0], [0.0, 1.0]])
      # check mu
      if mu.class != GSL::Vector
        raise ArgumentError, "mu should be GSL::Vector."
      end
      @mu = mu
      # check sigma2
      if sigma2.class != GSL::Matrix
        raise ArgumentError, "sigma2 should be GSL::Matrix."
      elsif sigma2.size1 != @mu.size || sigma2.size2 != @mu.size
        raise ArgumentError, "The size of sigma2 matrix does not match with mu vector."
      end
      @sigma2 = sigma2
      refresh_covariance_cache!
    end

    # Replaces the covariance matrix and recomputes the cached square root of
    # its determinant and its inverse.  A plain attr_accessor writer left
    # those caches stale, so the density kept using the old covariance.
    #
    # @param matrix [GSL::Matrix]
    def sigma2=(matrix)
      @sigma2 = matrix
      refresh_covariance_cache!
    end

    # Density at +x+:
    #   exp(-(x - mu) Sigma^-1 (x - mu)^T / 2) / ((sqrt(2 pi))^d * sqrt(det Sigma))
    #
    # @param x [GSL::Vector] row vector of the same size as mu
    # @return the density value (a scalar)
    def probability_density_function(x)
      diff = x - @mu
      exp(-(diff * @sigma2_invert * diff.trans)/2.0)/((sqrt(2.0*PI)**@mu.size)*@sqrt_sigma2_det)
    end

    # M-step for the mean: the responsibility-weighted average of the data.
    #
    # @param data_array [Array<GSL::Vector>] observations
    # @param temp_weight [Float] total responsibility of this component
    # @param temp_weight_per_datum [Array<Float>] this component's
    #   responsibility for each observation, aligned with +data_array+
    # @return [GSL::Vector] the new mean
    def update_average!(data_array, temp_weight, temp_weight_per_datum)
      data_sum = (0...data_array.size).inject(GSL::Vector.alloc(@mu.size).set_zero) do |sum, di|
        sum + temp_weight_per_datum[di] * data_array[di]
      end
      @mu = data_sum / temp_weight
    end

    # M-step for the covariance: the responsibility-weighted mean of the outer
    # products (x - mu)^T (x - mu) around the current +mu+.  The caches are
    # refreshed afterwards.
    #
    # @param (see #update_average!)
    # @return [GSL::Matrix] the new inverse covariance
    def update_sigma2!(data_array, temp_weight, temp_weight_per_datum)
      dim = @mu.size
      # Seed with an explicit zero matrix (as update_average! does with a zero
      # vector) instead of Float 0.0, which relied on GSL coercion.
      data_sum = (0...data_array.size).inject(GSL::Matrix.alloc(dim, dim).set_zero) do |sum, di|
        deviation = data_array[di] - @mu
        sum + temp_weight_per_datum[di] * deviation.trans * deviation
      end
      @sigma2 = data_sum / temp_weight
      refresh_covariance_cache!
    end

    # Runs both M-step updates.  The mean is updated first because the
    # covariance is measured around the freshly updated mean.
    #
    # @param (see #update_average!)
    def update_parameters!(data_array, temp_weight, temp_weight_per_datum)
      update_average!(data_array, temp_weight, temp_weight_per_datum)
      update_sigma2!(data_array, temp_weight, temp_weight_per_datum)
    end

    # gnuplot expression (in variables x, y) for a 2-D density.  For other
    # dimensions it returns a plain "N(mu, sigma2)" description instead.
    #
    # @return [String]
    def to_gnuplot
      if @mu.size == 2
        # With d = [x - mu_x, y - mu_y] and P = Sigma^-1 = [[p_x, p_xy], [p_yx, p_y]]:
        #   d P d^T = p_x*(x - mu_x)**2 + (p_xy + p_yx)*(x - mu_x)*(y - mu_y) + p_y*(y - mu_y)**2
        # The quadratic form must use the *inverse* covariance, matching
        # probability_density_function.  The covariance entries themselves
        # give the wrong surface unless Sigma is the identity.
        precision = @sigma2_invert
        cross = precision[0,1] + precision[1,0]
        xy = "+(#{cross.round(DIGIT)})*(x-(#{@mu[0].round(DIGIT)}))*(y-(#{@mu[1].round(DIGIT)}))" if cross > 0 || cross < 0
        "exp(-((#{precision[0,0].round(DIGIT)})*(x-(#{@mu[0].round(DIGIT)}))**2.0+(#{precision[1,1].round(DIGIT)})*(y-(#{@mu[1].round(DIGIT)}))**2.0#{xy})/2.0)/((sqrt(2.0*pi))**#{@mu.size}*(#{@sqrt_sigma2_det.round(DIGIT)}))"
      else
        "N(#{@mu.to_a.inspect}, #{@sigma2.to_a.inspect})"
      end
    end

    # gnuplot plot clause for a 2-D density titled "weight*N(mu,sigma2)".
    # For other dimensions it returns a plain description with no title.
    #
    # @param weight [Float] mixing weight shown in the title
    # @return [String]
    def to_gnuplot_with_title(weight)
      if @mu.size == 2
        to_gnuplot + " w l lw 3 title '#{weight.round(DIGIT)}*N(#{@mu.map{|mu| mu.round(DIGIT)}.to_a.inspect},#{@sigma2.map{|sigma2| sigma2.round(DIGIT)}.to_a.inspect})'"
      else
        "N(#{@mu.to_a.inspect}, #{@sigma2.to_a.inspect})"
      end
    end

    private

    # Recomputes the values derived from @sigma2 that the density uses.
    # Returns the new inverse covariance.
    def refresh_covariance_cache!
      @sqrt_sigma2_det = sqrt(@sigma2.det)
      @sigma2_invert = @sigma2.invert
    end
  end
end
|
@@ -0,0 +1,122 @@
|
|
1
|
+
module EMAlgorithm
  # Finite mixture model: a weighted sum of component models.  Components must
  # respond to +pdf+, +update_parameters!(data_array, temp_weight,
  # temp_weight_per_datum)+, +to_gnuplot+ and +to_gnuplot_with_title(weight)+.
  #
  # Between EM steps the object keeps two values for each component:
  # @temp_weight_per_datum[k] (its responsibility for every datum) and
  # @temp_weights[k] (the sum of those responsibilities).
  class Mixture < Model
    # Absolute tolerance used by #proper_weights?.  Exact float equality
    # rejects valid inputs such as [0.7, 0.2, 0.1] (whose float sum is
    # 0.9999999999999999) and weights re-estimated by update_weights!.
    WEIGHT_SUM_TOLERANCE = 1.0e-9

    attr_accessor :models, :weights

    # @param options [Hash]
    #   :models  - Array of component models (default: two 1-D Gaussians with
    #              means 0.0 and 10.0 and sigma 9.0)
    #   :weights - Array of mixing weights aligned with :models
    #              (default: [0.5, 0.5])
    # @raise [ArgumentError] if the weights do not sum to 1.0
    def initialize(options = {})
      opts = {
        :models => [Gaussian.new(0.0, 9.0), Gaussian.new(10.0, 9.0)],
        :weights => [0.5, 0.5]
      }.merge(options)
      @models = opts[:models]
      @weights = opts[:weights]
      argument_error unless proper_weights?
      @temp_weights = Array.new(@models.size)
      # block form so every component gets its own list
      @temp_weight_per_datum = Array.new(@models.size) { [] }
    end

    # @raise [ArgumentError] always; used when the mixing weights are invalid
    def argument_error
      raise ArgumentError, "The summation of @weights must be equal to 1.0."
    end

    # Appends a component with the given mixing weight.  The weights are
    # validated before anything is mutated, so a rejected call leaves the
    # mixture unchanged.  The per-component E-step buffers are grown to
    # match; previously they were not, and the next E-step crashed on nil.
    #
    # @param model [Model] component to add
    # @param weight [Float] its mixing weight (default 0.0 keeps the sum intact)
    # @raise [ArgumentError] if the weights would no longer sum to 1.0
    # @return [nil]
    def add(model = Gaussian.new(0.0, 9.0), weight = 0.0)
      argument_error unless proper_weights?(@weights + [weight])
      @models << model
      @weights << weight
      @temp_weights << nil
      @temp_weight_per_datum << []
      nil
    end

    # Whether +weights+ sum to 1.0 within WEIGHT_SUM_TOLERANCE.
    #
    # @param weights [Array<Float>] defaults to the current mixing weights
    # @return [Boolean]
    def proper_weights?(weights = @weights)
      total = weights.inject(0) { |sum, v| sum + v }
      (total - 1.0).abs <= WEIGHT_SUM_TOLERANCE
    end

    # Mixture density: sum over components of weight * component pdf.
    #
    # @param x a single observation, of whatever type the components accept
    # @return [Float]
    def probability_density_function(x)
      pdf = 0.0
      @models.each_with_index do |model, mi|
        pdf += model.pdf(x) * @weights[mi]
      end
      pdf
    end

    # Empties every component's per-datum responsibility list.
    def clear_temp_weight_per_datum!
      @temp_weight_per_datum.each { |w| w.clear }
    end

    # For each datum, computes sum_k weight_k * pdf_k(x).  Despite the name,
    # this is the mixture density of the datum (the E-step normaliser), not
    # a posterior.
    #
    # @param data_array [Array] observations
    # @return [Array<Float>] one value per datum
    def calculate_posterior_data_array(data_array)
      posterior_data_array = Array.new(data_array.size, 0.0)
      @models.each_with_index do |model, mi|
        data_array.each_with_index do |x, di|
          posterior_data_array[di] += @weights[mi] * model.pdf(x)
        end
      end
      posterior_data_array
    end

    # E-step: computes each component's responsibility for each datum,
    # weight_k * pdf_k(x) / normaliser(x), and its total.  Each component's
    # list is rebuilt rather than appended to, so leftovers from a previous
    # iteration can never be mixed in.
    #
    # @param data_array [Array] observations
    # @param posterior_data_array [Array<Float>] output of
    #   #calculate_posterior_data_array for the same data
    def update_temp_weights!(data_array, posterior_data_array)
      @models.each_with_index do |model, mi|
        responsibilities = []
        data_array.each_with_index do |x, di|
          r = @weights[mi] * model.pdf(x) / posterior_data_array[di]
          # NaN (e.g. 0/0 when no component gives x any density) -> no responsibility
          responsibilities << (r.nan? ? 0.0 : r)
        end
        @temp_weight_per_datum[mi] = responsibilities
        @temp_weights[mi] = responsibilities.inject(0.0) { |sum, w| sum + w }
      end
    end

    # M-step for the mixing weights: each component's total responsibility
    # divided by the number of data points.
    #
    # @param data_array [Array] observations (only its size is used)
    def update_weights!(data_array)
      (0..(@models.size-1)).each do |mi|
        @weights[mi] = @temp_weights[mi] / data_array.size
      end
    end

    # M-step: updates every component from its responsibilities, then the
    # mixing weights.
    #
    # @param data_array [Array] observations
    def update_parameters!(data_array)
      @models.each_with_index do |model, mi|
        model.update_parameters!(data_array, @temp_weights[mi], @temp_weight_per_datum[mi])
      end
      update_weights!(data_array)
    end

    # gnuplot output.
    #
    # @param type [Symbol] :full (default) both parts joined by ", ",
    #   :separate_only the individual weighted components as separate plot
    #   clauses, :mixture_only the single summed mixture expression
    # @return [String]
    def to_gnuplot(type = :full)
      # each component as its own titled plot clause (currently assumes Gaussian)
      separate = @models.each_with_index.map do |model, mi|
        "#{@weights[mi].round(DIGIT)} * #{model.to_gnuplot_with_title(@weights[mi])}"
      end.join(", ")
      # the whole mixture as one expression (currently assumes Gaussian mixture)
      mixture = @models.each_with_index.map do |model, mi|
        "#{@weights[mi].round(DIGIT)} * #{model.to_gnuplot}"
      end.join(" + ")

      case type
      when :separate_only then separate
      when :mixture_only then mixture
      else "#{separate}, #{mixture}"
      end
    end

    # Multi-line dump of the weights, responsibility totals and components.
    #
    # @return [String]
    def debug_output
      "@weights=#{@weights.inspect}\n" \
        "@temp_weights=#{@temp_weights.inspect}\n" \
        "@models\n" \
        "#{@models.inspect}\n"
    end
  end
end
|
@@ -0,0 +1,122 @@
|
|
1
|
+
module EMAlgorithm
  # Finite mixture model: a weighted sum of component models.  Components must
  # respond to +pdf+, +update_parameters!(data_array, temp_weight,
  # temp_weight_per_datum)+, +to_gnuplot+ and +to_gnuplot_with_title(weight)+.
  #
  # Between EM steps the object keeps two values for each component:
  # @temp_weight_per_datum[k] (its responsibility for every datum) and
  # @temp_weights[k] (the sum of those responsibilities).
  class Mixture < Model
    # Absolute tolerance used by #proper_weights?.  Exact float equality
    # rejects valid inputs such as [0.7, 0.2, 0.1] (whose float sum is
    # 0.9999999999999999) and weights re-estimated by update_weights!.
    WEIGHT_SUM_TOLERANCE = 1.0e-9

    # NOTE(review): `const` is exposed but never assigned in this class.
    attr_accessor :models, :weights, :const

    # @param options [Hash]
    #   :models  - Array of component models (default: two 1-D Gaussians with
    #              means 0.0 and 10.0 and sigma 9.0)
    #   :weights - Array of mixing weights aligned with :models
    #              (default: [0.5, 0.5])
    # @raise [ArgumentError] if the weights do not sum to 1.0
    def initialize(options = {})
      opts = {
        :models => [Gaussian.new(0.0, 9.0), Gaussian.new(10.0, 9.0)],
        :weights => [0.5, 0.5]
      }.merge(options)
      @models = opts[:models]
      @weights = opts[:weights]
      argument_error unless proper_weights?
      @temp_weights = Array.new(@models.size)
      # block form so every component gets its own list
      @temp_weight_per_datum = Array.new(@models.size) { [] }
    end

    # @raise [ArgumentError] always; used when the mixing weights are invalid
    def argument_error
      raise ArgumentError, "The summation of @weights must be equal to 1.0."
    end

    # Appends a component with the given mixing weight.  The weights are
    # validated before anything is mutated, so a rejected call leaves the
    # mixture unchanged.  The per-component E-step buffers are grown to
    # match; previously they were not, and the next E-step crashed on nil.
    #
    # @param model [Model] component to add
    # @param weight [Float] its mixing weight (default 0.0 keeps the sum intact)
    # @raise [ArgumentError] if the weights would no longer sum to 1.0
    # @return [nil]
    def add(model = Gaussian.new(0.0, 9.0), weight = 0.0)
      argument_error unless proper_weights?(@weights + [weight])
      @models << model
      @weights << weight
      @temp_weights << nil
      @temp_weight_per_datum << []
      nil
    end

    # Whether +weights+ sum to 1.0 within WEIGHT_SUM_TOLERANCE.
    #
    # @param weights [Array<Float>] defaults to the current mixing weights
    # @return [Boolean]
    def proper_weights?(weights = @weights)
      total = weights.inject(0) { |sum, v| sum + v }
      (total - 1.0).abs <= WEIGHT_SUM_TOLERANCE
    end

    # Mixture density: sum over components of weight * component pdf.
    #
    # @param x a single observation, of whatever type the components accept
    # @return [Float]
    def probability_density_function(x)
      pdf = 0.0
      @models.each_with_index do |model, mi|
        pdf += model.pdf(x) * @weights[mi]
      end
      pdf
    end

    # Empties every component's per-datum responsibility list.
    def clear_temp_weight_per_datum!
      @temp_weight_per_datum.each { |w| w.clear }
    end

    # For each datum, computes sum_k weight_k * pdf_k(x).  Despite the name,
    # this is the mixture density of the datum (the E-step normaliser), not
    # a posterior.
    #
    # @param data_array [Array] observations
    # @return [Array<Float>] one value per datum
    def calculate_posterior_data_array(data_array)
      posterior_data_array = Array.new(data_array.size, 0.0)
      @models.each_with_index do |model, mi|
        data_array.each_with_index do |x, di|
          posterior_data_array[di] += @weights[mi] * model.pdf(x)
        end
      end
      posterior_data_array
    end

    # E-step: computes each component's responsibility for each datum,
    # weight_k * pdf_k(x) / normaliser(x), and its total.  Each component's
    # list is rebuilt rather than appended to, so leftovers from a previous
    # iteration can never be mixed in.
    #
    # @param data_array [Array] observations
    # @param posterior_data_array [Array<Float>] output of
    #   #calculate_posterior_data_array for the same data
    def update_temp_weights!(data_array, posterior_data_array)
      @models.each_with_index do |model, mi|
        responsibilities = []
        data_array.each_with_index do |x, di|
          r = @weights[mi] * model.pdf(x) / posterior_data_array[di]
          # NaN (e.g. 0/0 when no component gives x any density) -> no responsibility
          responsibilities << (r.nan? ? 0.0 : r)
        end
        @temp_weight_per_datum[mi] = responsibilities
        @temp_weights[mi] = responsibilities.inject(0.0) { |sum, w| sum + w }
      end
    end

    # M-step for the mixing weights: each component's total responsibility
    # divided by the number of data points.
    #
    # @param data_array [Array] observations (only its size is used)
    def update_weights!(data_array)
      (0..(@models.size-1)).each do |mi|
        @weights[mi] = @temp_weights[mi] / data_array.size
      end
    end

    # M-step: updates every component from its responsibilities, then the
    # mixing weights.
    #
    # @param data_array [Array] observations
    def update_parameters!(data_array)
      @models.each_with_index do |model, mi|
        model.update_parameters!(data_array, @temp_weights[mi], @temp_weight_per_datum[mi])
      end
      update_weights!(data_array)
    end

    # gnuplot output.
    #
    # @param type [Symbol] :full (default) both parts joined by ", ",
    #   :separate_only the individual weighted components as separate plot
    #   clauses, :mixture_only the single summed mixture expression
    # @return [String]
    def to_gnuplot(type = :full)
      # each component as its own titled plot clause (currently assumes Gaussian)
      separate = @models.each_with_index.map do |model, mi|
        "#{@weights[mi].round(DIGIT)} * #{model.to_gnuplot_with_title(@weights[mi])}"
      end.join(", ")
      # the whole mixture as one expression (currently assumes Gaussian mixture)
      mixture = @models.each_with_index.map do |model, mi|
        "#{@weights[mi].round(DIGIT)} * #{model.to_gnuplot}"
      end.join(" + ")

      case type
      when :separate_only then separate
      when :mixture_only then mixture
      else "#{separate}, #{mixture}"
      end
    end

    # Multi-line dump of the weights, responsibility totals and components.
    #
    # @return [String]
    def debug_output
      "@weights=#{@weights.inspect}\n" \
        "@temp_weights=#{@temp_weights.inspect}\n" \
        "@models\n" \
        "#{@models.inspect}\n"
    end
  end
end
|