ruby-em_algorithm 0.0.2
Sign up to get free protection for your applications and to get access to all the features.
- data/Gemfile +6 -0
- data/Gemfile.lock +30 -0
- data/README.md +44 -0
- data/Rakefile +7 -0
- data/example/.ex1.rb.swp +0 -0
- data/example/.ex2.rb.swp +0 -0
- data/example/.ex3-tmp.rb.swp +0 -0
- data/example/.ex3.rb.swp +0 -0
- data/example/data/2dim-gmm-new.txt +1267 -0
- data/example/data/2dim-gmm-simple.txt +676 -0
- data/example/data/2dim-gmm-test.txt +6565 -0
- data/example/data/2dim-gmm-test2.txt +2782 -0
- data/example/data/2dim-gmm-test3.csv +1641 -0
- data/example/data/2dim-gmm-test3.txt +2782 -0
- data/example/data/2dim-gmm-test4.csv +868 -0
- data/example/data/2dim-gmm-test4.txt +4924 -0
- data/example/data/2dim-gmm-without_weight-small.txt +2401 -0
- data/example/data/2dim-gmm-without_weight.txt +18001 -0
- data/example/data/2dim-gmm.txt +1267 -0
- data/example/data/gmm-new.txt +10001 -0
- data/example/data/gmm-simple.txt +676 -0
- data/example/data/gmm.txt +10001 -0
- data/example/data/old-gmm.txt +10000 -0
- data/example/ex1.rb +20 -0
- data/example/ex1.rb~ +20 -0
- data/example/ex2.rb +33 -0
- data/example/ex2.rb~ +33 -0
- data/example/ex3-tmp.rb +23 -0
- data/example/ex3-tmp.rb~ +25 -0
- data/example/ex3.rb +43 -0
- data/example/ex3.rb~ +43 -0
- data/example/tools/.2dim.rb.swp +0 -0
- data/example/tools/2dim.rb +69 -0
- data/example/tools/2dim.rb~ +69 -0
- data/example/tools/boxmuller.rb +28 -0
- data/example/tools/boxmuller.rb~ +28 -0
- data/example/tools/conv_from_yaml.rb +8 -0
- data/example/tools/conv_from_yaml_to_csv.rb +8 -0
- data/example/tools/conv_to_yaml.rb +17 -0
- data/example/tools/ellipsoid.gnuplot +63 -0
- data/example/tools/ellipsoid.gnuplot~ +64 -0
- data/example/tools/histogram.rb +19 -0
- data/example/tools/histogram2d.rb +20 -0
- data/example/tools/histogram2d.rb~ +18 -0
- data/example/tools/kmeans.rb +34 -0
- data/example/tools/mean.rb +19 -0
- data/example/tools/table.data +4618 -0
- data/example/tools/tmp.txt +69632 -0
- data/example/tools/xmeans.R +608 -0
- data/example/tools/xmeans.rb +35 -0
- data/lib/em_algorithm/.base.rb.swp +0 -0
- data/lib/em_algorithm/base.rb +116 -0
- data/lib/em_algorithm/base.rb~ +116 -0
- data/lib/em_algorithm/convergence/.chi_square.rb.swp +0 -0
- data/lib/em_algorithm/convergence/.likelihood.rb.swp +0 -0
- data/lib/em_algorithm/convergence/check_method.rb +4 -0
- data/lib/em_algorithm/convergence/check_method.rb~ +0 -0
- data/lib/em_algorithm/convergence/chi_square.rb +40 -0
- data/lib/em_algorithm/convergence/chi_square.rb~ +40 -0
- data/lib/em_algorithm/convergence/likelihood.rb +35 -0
- data/lib/em_algorithm/convergence/likelihood.rb~ +35 -0
- data/lib/em_algorithm/models/.gaussian.rb.swp +0 -0
- data/lib/em_algorithm/models/.md_gaussian.rb.swp +0 -0
- data/lib/em_algorithm/models/.mixture.rb.swp +0 -0
- data/lib/em_algorithm/models/.model.rb.swp +0 -0
- data/lib/em_algorithm/models/gaussian.rb +47 -0
- data/lib/em_algorithm/models/gaussian.rb~ +47 -0
- data/lib/em_algorithm/models/md_gaussian.rb +67 -0
- data/lib/em_algorithm/models/md_gaussian.rb~ +67 -0
- data/lib/em_algorithm/models/mixture.rb +122 -0
- data/lib/em_algorithm/models/mixture.rb~ +122 -0
- data/lib/em_algorithm/models/model.rb +19 -0
- data/lib/em_algorithm/models/model.rb~ +19 -0
- data/lib/ruby-em_algorithm.rb +3 -0
- data/lib/ruby-em_algorithm/version.rb +3 -0
- data/ruby-em_algorithm.gemspec +21 -0
- data/spec/spec_helper.rb +9 -0
- metadata +178 -0
@@ -0,0 +1,47 @@
|
|
1
|
+
module EMAlgorithm
  # One-dimensional Gaussian component described by a mean (+mu+) and a
  # standard deviation (+sigma+).
  #
  # NOTE(review): +exp+, +sqrt+, +PI+ and +DIGIT+ are not defined in this
  # class; presumably Model (or the enclosing namespace) provides them.
  class Gaussian < Model
    attr_accessor :mu, :sigma, :dim

    # @param mu [Numeric] initial mean
    # @param sigma [Numeric] initial standard deviation
    def initialize(mu = 0.0, sigma = 1.0)
      @mu, @sigma = mu, sigma
    end

    # Density of N(mu, sigma^2) evaluated at +x+.
    #
    # @param x [Numeric] point at which the density is evaluated
    # @return [Float]
    def probability_density_function(x)
      exponent = -((x - @mu)**2.0) / (2.0 * @sigma**2)
      normalizer = sqrt(2.0 * PI) * @sigma
      exp(exponent) / normalizer
    end

    # Density at an observation scaled by that observation's weight.
    #
    # @param x_with_weight [#[]] element 0 is the value, element 1 its weight
    # @return [Numeric] weight * density(value)
    def probability_density_function_with_observation_weight(x_with_weight)
      value = x_with_weight[0]
      weight = x_with_weight[1]
      weight * probability_density_function(value)
    end

    # Replaces +mu+ with the weighted mean of the data.
    #
    # @param data_array [#size, #[]] observations
    # @param temp_weight [Numeric] divisor applied to the weighted sum
    # @param temp_weight_per_datum [#[]] weight of each observation, by index
    # @return [Numeric] the new +mu+
    def update_average!(data_array, temp_weight, temp_weight_per_datum)
      weighted_total = 0.0
      data_array.size.times do |i|
        weighted_total += temp_weight_per_datum[i] * data_array[i]
      end
      @mu = weighted_total / temp_weight
    end

    # Replaces +sigma+ with the square root of the weighted mean squared
    # deviation from the current +mu+.
    #
    # @param data_array [#size, #[]] observations
    # @param temp_weight [Numeric] divisor applied to the weighted sum
    # @param temp_weight_per_datum [#[]] weight of each observation, by index
    # @return [Numeric] the new +sigma+
    def update_sigma!(data_array, temp_weight, temp_weight_per_datum)
      weighted_squares = 0.0
      data_array.size.times do |i|
        deviation = data_array[i] - @mu
        weighted_squares += temp_weight_per_datum[i] * deviation**2
      end
      @sigma = sqrt(weighted_squares / temp_weight)
    end

    # Updates the mean first, then sigma, so that sigma is computed around
    # the freshly updated mean.
    def update_parameters!(data_array, temp_weight, temp_weight_per_datum)
      update_average!(data_array, temp_weight, temp_weight_per_datum)
      update_sigma!(data_array, temp_weight, temp_weight_per_datum)
    end

    # Gnuplot expression of the density in +x+, with every parameter rounded
    # to DIGIT decimal places.
    #
    # @return [String]
    def to_gnuplot
      mean = @mu.round(DIGIT)
      variance = (@sigma**2).round(DIGIT)
      deviation = @sigma.round(DIGIT)
      "exp(-((x-(#{mean}))**2.0)/(2.0*#{variance}))/(sqrt(2.0*pi)*#{deviation})"
    end

    # Gnuplot expression followed by line-style options and a title of the
    # form "weight*N(mu,sigma)".
    #
    # @param weight [Numeric] mixture weight shown in the title
    # @return [String]
    def to_gnuplot_with_title(weight)
      curve = to_gnuplot
      title = "#{weight.round(DIGIT)}*N(#{@mu.round(DIGIT)},#{@sigma.round(DIGIT)})"
      "#{curve} w l axis x1y2 lw 3 title '#{title}'"
    end
  end
end
|
@@ -0,0 +1,47 @@
|
|
1
|
+
module EMAlgorithm
  # One-dimensional Gaussian component described by a mean (+mu+) and a
  # standard deviation (+sigma+).
  #
  # NOTE(review): +exp+, +sqrt+, +PI+ and +DIGIT+ are not defined in this
  # class; presumably Model (or the enclosing namespace) provides them.
  class Gaussian < Model
    attr_accessor :mu, :sigma, :dim

    # @param mu [Numeric] initial mean
    # @param sigma [Numeric] initial standard deviation
    def initialize(mu = 0.0, sigma = 1.0)
      @mu, @sigma = mu, sigma
    end

    # Density of N(mu, sigma^2) evaluated at +x+.
    #
    # @param x [Numeric] point at which the density is evaluated
    # @return [Float]
    def probability_density_function(x)
      exponent = -((x - @mu)**2.0) / (2.0 * @sigma**2)
      normalizer = sqrt(2.0 * PI) * @sigma
      exp(exponent) / normalizer
    end

    # Density at an observation scaled by that observation's weight.
    #
    # @param x_with_weight [#[]] element 0 is the value, element 1 its weight
    # @return [Numeric] weight * density(value)
    def probability_density_function_with_observation_weight(x_with_weight)
      value = x_with_weight[0]
      weight = x_with_weight[1]
      weight * probability_density_function(value)
    end

    # Replaces +mu+ with the weighted mean of the data.
    #
    # @param data_array [#size, #[]] observations
    # @param temp_weight [Numeric] divisor applied to the weighted sum
    # @param temp_weight_per_datum [#[]] weight of each observation, by index
    # @return [Numeric] the new +mu+
    def update_average!(data_array, temp_weight, temp_weight_per_datum)
      weighted_total = 0.0
      data_array.size.times do |i|
        weighted_total += temp_weight_per_datum[i] * data_array[i]
      end
      @mu = weighted_total / temp_weight
    end

    # Replaces +sigma+ with the square root of the weighted mean squared
    # deviation from the current +mu+.
    #
    # @param data_array [#size, #[]] observations
    # @param temp_weight [Numeric] divisor applied to the weighted sum
    # @param temp_weight_per_datum [#[]] weight of each observation, by index
    # @return [Numeric] the new +sigma+
    def update_sigma!(data_array, temp_weight, temp_weight_per_datum)
      weighted_squares = 0.0
      data_array.size.times do |i|
        deviation = data_array[i] - @mu
        weighted_squares += temp_weight_per_datum[i] * deviation**2
      end
      @sigma = sqrt(weighted_squares / temp_weight)
    end

    # Updates the mean first, then sigma, so that sigma is computed around
    # the freshly updated mean.
    def update_parameters!(data_array, temp_weight, temp_weight_per_datum)
      update_average!(data_array, temp_weight, temp_weight_per_datum)
      update_sigma!(data_array, temp_weight, temp_weight_per_datum)
    end

    # Gnuplot expression of the density in +x+, with every parameter rounded
    # to DIGIT decimal places.
    #
    # @return [String]
    def to_gnuplot
      mean = @mu.round(DIGIT)
      variance = (@sigma**2).round(DIGIT)
      deviation = @sigma.round(DIGIT)
      "exp(-((x-(#{mean}))**2.0)/(2.0*#{variance}))/(sqrt(2.0*pi)*#{deviation})"
    end

    # Gnuplot expression followed by line-style options and a title of the
    # form "weight*N(mu,sigma)".
    #
    # @param weight [Numeric] mixture weight shown in the title
    # @return [String]
    def to_gnuplot_with_title(weight)
      curve = to_gnuplot
      title = "#{weight.round(DIGIT)}*N(#{@mu.round(DIGIT)},#{@sigma.round(DIGIT)})"
      "#{curve} w l axis x1y2 lw 3 title '#{title}'"
    end
  end
end
|
@@ -0,0 +1,67 @@
|
|
1
|
+
module EMAlgorithm
  # Multi-dimensional Gaussian component described by a mean vector (+mu+)
  # and a covariance matrix (+sigma2+).
  #
  # The square root of the covariance determinant and the inverse covariance
  # are cached because the density needs both on every evaluation.
  #
  # NOTE(review): +exp+, +sqrt+, +PI+ and +DIGIT+ are not defined in this
  # class; presumably Model (or the enclosing namespace) provides them.
  class MdGaussian < Model
    attr_accessor :mu
    attr_reader :sigma2

    # @param mu [GSL::Vector] mean vector
    # @param sigma2 [GSL::Matrix] covariance matrix, square, sized like +mu+
    # @raise [ArgumentError] if either argument has the wrong class or the
    #   matrix dimensions do not match the vector length
    def initialize(mu = GSL::Vector[0.0, 0.0], sigma2 = GSL::Matrix[[1.0, 0.0], [0.0, 1.0]])
      # check mu
      if mu.class != GSL::Vector
        raise ArgumentError, "mu should be GSL::Vector."
      end
      @mu = mu
      # check sigma2
      if sigma2.class != GSL::Matrix
        raise ArgumentError, "sigma2 should be GSL::Matrix."
      elsif sigma2.size1 != @mu.size || sigma2.size2 != @mu.size
        raise ArgumentError, "The size of sigma2 matrix does not match with mu vector."
      end
      self.sigma2 = sigma2
    end

    # Assigns the covariance matrix and refreshes the cached determinant
    # root and inverse, so the density never uses values derived from a
    # previous covariance.
    #
    # @param sigma2 [GSL::Matrix] new covariance matrix
    def sigma2=(sigma2)
      @sigma2 = sigma2
      @sqrt_sigma2_det = sqrt(@sigma2.det)
      @sigma2_invert = @sigma2.invert
    end

    # Density at +x+:
    # exp(-(x-mu) * sigma2^-1 * (x-mu)^T / 2) / (sqrt(2*PI)**dim * sqrt(det sigma2)).
    #
    # @param x [GSL::Vector] presumably a row vector shaped like +mu+ - confirm with callers
    def probability_density_function(x)
      exp(-((x-@mu) * @sigma2_invert * (x-@mu).trans)/2.0)/((sqrt(2.0*PI)**@mu.size)*@sqrt_sigma2_det)
    end

    # Replaces +mu+ with the weighted mean of the data vectors.
    #
    # @param data_array [#size, #[]] observations
    # @param temp_weight [Numeric] divisor applied to the weighted sum
    # @param temp_weight_per_datum [#[]] weight of each observation, by index
    def update_average!(data_array, temp_weight, temp_weight_per_datum)
      data_sum = (0..(data_array.size-1)).inject(GSL::Vector.alloc(@mu.size).set_zero) do |sum, di|
        sum + temp_weight_per_datum[di] * data_array[di]
      end
      @mu = data_sum / temp_weight
    end

    # Replaces +sigma2+ with the weighted mean of the products
    # (x - mu)^T * (x - mu) around the current +mu+, and refreshes the caches.
    #
    # @param data_array [#size, #[]] observations
    # @param temp_weight [Numeric] divisor applied to the weighted sum
    # @param temp_weight_per_datum [#[]] weight of each observation, by index
    # @return [GSL::Matrix] the refreshed inverse covariance
    def update_sigma2!(data_array, temp_weight, temp_weight_per_datum)
      data_sum = (0..(data_array.size-1)).inject(0.0) do |sum, di|
        sum + temp_weight_per_datum[di] * (data_array[di] - @mu).trans * (data_array[di] - @mu)
      end
      self.sigma2 = (data_sum / temp_weight)
      @sigma2_invert
    end

    # Updates the mean first, then the covariance, so that the covariance is
    # computed around the freshly updated mean.
    def update_parameters!(data_array, temp_weight, temp_weight_per_datum)
      update_average!(data_array, temp_weight, temp_weight_per_datum)
      update_sigma2!(data_array, temp_weight, temp_weight_per_datum)
    end

    # Gnuplot expression of the density in +x+ and +y+ for the 2-D case;
    # for any other dimension a plain "N(mu, sigma2)" label is returned.
    #
    # @return [String]
    def to_gnuplot
      if @mu.size == 2
        # The exponent is the quadratic form of the INVERSE covariance P:
        # [x - mu_x, y - mu_y] * [[p_x, p_xy], [p_yx, p_y]] * [x - mu_x, y - mu_y]
        # = p_x*(x - mu_x)**2 + (p_xy + p_yx)*(x - mu_x)(y - mu_y) + p_y*(y - mu_y)**2
        # Using the covariance itself would disagree with
        # probability_density_function.
        precision = @sigma2_invert
        precision_xy = precision[0,1] + precision[1,0]
        xy = "+(#{precision_xy.round((DIGIT))})*(x-(#{@mu[0].round((DIGIT))}))*(y-(#{@mu[1].round((DIGIT))}))" if precision_xy > 0 || precision_xy < 0
        "exp(-((#{precision[0,0].round((DIGIT))})*(x-(#{@mu[0].round((DIGIT))}))**2.0+(#{precision[1,1].round((DIGIT))})*(y-(#{@mu[1].round((DIGIT))}))**2.0#{xy})/2.0)/((sqrt(2.0*pi))**#{@mu.size}*(#{@sqrt_sigma2_det.round((DIGIT))}))"
      else
        "N(#{@mu.to_a.inspect}, #{@sigma2.to_a.inspect})"
      end
    end

    # 2-D case: the gnuplot expression followed by line-style options and a
    # title showing the weight, mean and covariance. Other dimensions: the
    # plain "N(mu, sigma2)" label.
    #
    # @param weight [Numeric] mixture weight shown in the title
    # @return [String]
    def to_gnuplot_with_title(weight)
      if @mu.size == 2
        to_gnuplot + " w l lw 3 title '#{weight.round((DIGIT))}*N(#{@mu.map{|mu| mu.round((DIGIT))}.to_a.inspect},#{@sigma2.map{|sigma2| sigma2.round((DIGIT))}.to_a.inspect})'"
      else
        "N(#{@mu.to_a.inspect}, #{@sigma2.to_a.inspect})"
      end
    end
  end
end
|
@@ -0,0 +1,67 @@
|
|
1
|
+
module EMAlgorithm
  # Multi-dimensional Gaussian component described by a mean vector (+mu+)
  # and a covariance matrix (+sigma2+).
  #
  # The square root of the covariance determinant and the inverse covariance
  # are cached because the density needs both on every evaluation.
  #
  # NOTE(review): +exp+, +sqrt+, +PI+ and +DIGIT+ are not defined in this
  # class; presumably Model (or the enclosing namespace) provides them.
  class MdGaussian < Model
    attr_accessor :mu
    attr_reader :sigma2

    # @param mu [GSL::Vector] mean vector
    # @param sigma2 [GSL::Matrix] covariance matrix, square, sized like +mu+
    # @raise [ArgumentError] if either argument has the wrong class or the
    #   matrix dimensions do not match the vector length
    def initialize(mu = GSL::Vector[0.0, 0.0], sigma2 = GSL::Matrix[[1.0, 0.0], [0.0, 1.0]])
      # check mu
      if mu.class != GSL::Vector
        raise ArgumentError, "mu should be GSL::Vector."
      end
      @mu = mu
      # check sigma2
      if sigma2.class != GSL::Matrix
        raise ArgumentError, "sigma2 should be GSL::Matrix."
      elsif sigma2.size1 != @mu.size || sigma2.size2 != @mu.size
        raise ArgumentError, "The size of sigma2 matrix does not match with mu vector."
      end
      self.sigma2 = sigma2
    end

    # Assigns the covariance matrix and refreshes the cached determinant
    # root and inverse, so the density never uses values derived from a
    # previous covariance.
    #
    # @param sigma2 [GSL::Matrix] new covariance matrix
    def sigma2=(sigma2)
      @sigma2 = sigma2
      @sqrt_sigma2_det = sqrt(@sigma2.det)
      @sigma2_invert = @sigma2.invert
    end

    # Density at +x+:
    # exp(-(x-mu) * sigma2^-1 * (x-mu)^T / 2) / (sqrt(2*PI)**dim * sqrt(det sigma2)).
    #
    # @param x [GSL::Vector] presumably a row vector shaped like +mu+ - confirm with callers
    def probability_density_function(x)
      exp(-((x-@mu) * @sigma2_invert * (x-@mu).trans)/2.0)/((sqrt(2.0*PI)**@mu.size)*@sqrt_sigma2_det)
    end

    # Replaces +mu+ with the weighted mean of the data vectors.
    #
    # @param data_array [#size, #[]] observations
    # @param temp_weight [Numeric] divisor applied to the weighted sum
    # @param temp_weight_per_datum [#[]] weight of each observation, by index
    def update_average!(data_array, temp_weight, temp_weight_per_datum)
      data_sum = (0..(data_array.size-1)).inject(GSL::Vector.alloc(@mu.size).set_zero) do |sum, di|
        sum + temp_weight_per_datum[di] * data_array[di]
      end
      @mu = data_sum / temp_weight
    end

    # Replaces +sigma2+ with the weighted mean of the products
    # (x - mu)^T * (x - mu) around the current +mu+, and refreshes the caches.
    #
    # @param data_array [#size, #[]] observations
    # @param temp_weight [Numeric] divisor applied to the weighted sum
    # @param temp_weight_per_datum [#[]] weight of each observation, by index
    # @return [GSL::Matrix] the refreshed inverse covariance
    def update_sigma2!(data_array, temp_weight, temp_weight_per_datum)
      data_sum = (0..(data_array.size-1)).inject(0.0) do |sum, di|
        sum + temp_weight_per_datum[di] * (data_array[di] - @mu).trans * (data_array[di] - @mu)
      end
      self.sigma2 = (data_sum / temp_weight)
      @sigma2_invert
    end

    # Updates the mean first, then the covariance, so that the covariance is
    # computed around the freshly updated mean.
    def update_parameters!(data_array, temp_weight, temp_weight_per_datum)
      update_average!(data_array, temp_weight, temp_weight_per_datum)
      update_sigma2!(data_array, temp_weight, temp_weight_per_datum)
    end

    # Gnuplot expression of the density in +x+ and +y+ for the 2-D case;
    # for any other dimension a plain "N(mu, sigma2)" label is returned.
    #
    # @return [String]
    def to_gnuplot
      if @mu.size == 2
        # The exponent is the quadratic form of the INVERSE covariance P:
        # [x - mu_x, y - mu_y] * [[p_x, p_xy], [p_yx, p_y]] * [x - mu_x, y - mu_y]
        # = p_x*(x - mu_x)**2 + (p_xy + p_yx)*(x - mu_x)(y - mu_y) + p_y*(y - mu_y)**2
        # Using the covariance itself would disagree with
        # probability_density_function.
        precision = @sigma2_invert
        precision_xy = precision[0,1] + precision[1,0]
        xy = "+(#{precision_xy.round((DIGIT))})*(x-(#{@mu[0].round((DIGIT))}))*(y-(#{@mu[1].round((DIGIT))}))" if precision_xy > 0 || precision_xy < 0
        "exp(-((#{precision[0,0].round((DIGIT))})*(x-(#{@mu[0].round((DIGIT))}))**2.0+(#{precision[1,1].round((DIGIT))})*(y-(#{@mu[1].round((DIGIT))}))**2.0#{xy})/2.0)/((sqrt(2.0*pi))**#{@mu.size}*(#{@sqrt_sigma2_det.round((DIGIT))}))"
      else
        "N(#{@mu.to_a.inspect}, #{@sigma2.to_a.inspect})"
      end
    end

    # 2-D case: the gnuplot expression followed by line-style options and a
    # title showing the weight, mean and covariance. Other dimensions: the
    # plain "N(mu, sigma2)" label.
    #
    # @param weight [Numeric] mixture weight shown in the title
    # @return [String]
    def to_gnuplot_with_title(weight)
      if @mu.size == 2
        to_gnuplot + " w l lw 3 title '#{weight.round((DIGIT))}*N(#{@mu.map{|mu| mu.round((DIGIT))}.to_a.inspect},#{@sigma2.map{|sigma2| sigma2.round((DIGIT))}.to_a.inspect})'"
      else
        "N(#{@mu.to_a.inspect}, #{@sigma2.to_a.inspect})"
      end
    end
  end
end
|
@@ -0,0 +1,122 @@
|
|
1
|
+
module EMAlgorithm
  # Weighted mixture of component models.
  #
  # +@weights[i]+ is the mixing weight of +@models[i]+. +@temp_weights+ and
  # +@temp_weight_per_datum+ are scratch storage, kept parallel to +@models+,
  # that update_temp_weights! fills and update_parameters! consumes.
  #
  # NOTE(review): +DIGIT+ and the components' +pdf+ method are not defined in
  # this class; presumably Model (or the enclosing namespace) provides them.
  class Mixture < Model
    # Largest deviation from 1.0 tolerated for the sum of the weights.
    # Floating-point sums are rarely exact (ten weights of 0.1 add up to
    # 0.9999999999999999), so an exact comparison rejects valid input.
    WEIGHT_SUM_TOLERANCE = 1.0e-9

    attr_accessor :models, :weights

    # @param options [Hash] may override +:models+ (Array of components) and
    #   +:weights+ (Array of Numeric); defaults are two Gaussians weighted 0.5
    # @raise [ArgumentError] if the weights do not sum to 1.0
    def initialize(options)
      opts = {
        :models => [Gaussian.new(0.0, 9.0), Gaussian.new(10.0, 9.0)],
        :weights => [0.5, 0.5]
      }.merge(options)
      @models = opts[:models]
      @weights = opts[:weights]
      argument_error unless proper_weights?
      @temp_weights = Array.new(@models.size)
      @temp_weight_per_datum = Array.new(@models.size) { [] }
    end

    # @raise [ArgumentError] always
    def argument_error
      raise ArgumentError, "The summation of @weights must be equal to 1.0."
    end

    # Appends a component and its weight.
    #
    # The weights are validated before anything is mutated, so a rejected
    # call leaves the mixture untouched. The scratch arrays grow together
    # with +@models+ so the new component has a slot in update_temp_weights!.
    #
    # @param model [Model] component to append
    # @param weight [Numeric] its mixing weight
    # @raise [ArgumentError] if the weights would no longer sum to 1.0
    # @return [nil]
    def add(model = Gaussian.new(0.0, 9.0), weight = 0.0)
      argument_error unless unit_sum?(weights_total + weight)
      @models << model
      @weights << weight
      @temp_weights << nil
      @temp_weight_per_datum << []
      nil
    end

    # @return [Boolean] whether the weights sum to 1.0 within
    #   WEIGHT_SUM_TOLERANCE
    def proper_weights?
      unit_sum?(weights_total)
    end

    # Mixture density at +x+: the weighted sum of the component densities.
    def probability_density_function(x)
      pdf = 0.0
      @models.each_with_index do |model, mi|
        pdf += model.pdf(x) * @weights[mi]
      end
      pdf
    end

    # Empties the per-datum scratch weights of every component.
    def clear_temp_weight_per_datum!
      @temp_weight_per_datum.each { |w| w.clear }
    end

    # For every datum, the weighted sum of the component densities at it.
    #
    # @param data_array [Array] observations
    # @return [Array<Float>] one value per observation, in the same order
    def calculate_posterior_data_array(data_array)
      posterior_data_array = Array.new(data_array.size, 0.0)
      @models.each_with_index do |model, mi|
        data_array.each_with_index do |x, di|
          posterior_data_array[di] += @weights[mi] * model.pdf(x)
        end
      end
      posterior_data_array
    end

    # For each component and datum, appends
    # weight * pdf(x) / posterior_data_array[di] to the component's scratch
    # list (NaN is stored as 0.0), then stores the list's sum in
    # +@temp_weights+.
    #
    # Values are appended, not replaced: the scratch lists are only emptied
    # by clear_temp_weight_per_datum!.
    #
    # @param data_array [Array] observations
    # @param posterior_data_array [Array<Numeric>] per-datum divisors, as
    #   returned by calculate_posterior_data_array
    def update_temp_weights!(data_array, posterior_data_array)
      @models.each_with_index do |model, mi|
        # Create the slot on demand in case @models grew through its writer.
        per_datum = (@temp_weight_per_datum[mi] ||= [])
        data_array.each_with_index do |x, di|
          temp_weight_per_datum = @weights[mi] * model.pdf(x) / posterior_data_array[di]
          temp_weight_per_datum = 0.0 if temp_weight_per_datum.nan?
          per_datum << temp_weight_per_datum
        end
        @temp_weights[mi] = per_datum.inject(0.0) { |sum, w| sum + w }
      end
    end

    # Sets each mixing weight to its component's scratch weight divided by
    # the number of observations.
    def update_weights!(data_array)
      (0..(@models.size-1)).each do |mi|
        @weights[mi] = @temp_weights[mi] / data_array.size
      end
    end

    # Lets every component update its own parameters from the scratch
    # weights, then refreshes the mixing weights.
    def update_parameters!(data_array)
      @models.each_with_index do |model, mi|
        model.update_parameters!(data_array, @temp_weights[mi], @temp_weight_per_datum[mi])
      end
      update_weights!(data_array)
    end

    # Gnuplot expression(s) for the mixture.
    #
    # output types
    #   :full (default)  - the separate curves, then the mixture curve
    #   :separate_only   - one weighted, titled curve per component
    #   :mixture_only    - the sum of the weighted component curves
    #
    # @param type [Symbol]
    # @return [String]
    def to_gnuplot(type = :full)
      # output each model (currently assume Gaussian)
      separate = @models.each_with_index.map do |model, mi|
        "#{@weights[mi].round(DIGIT)} * #{model.to_gnuplot_with_title(@weights[mi])}"
      end.join(", ")
      # output mixture model (currently assume Gaussian Mixture model)
      mixture = @models.each_with_index.map do |model, mi|
        "#{@weights[mi].round(DIGIT)} * #{model.to_gnuplot}"
      end.join(" + ")
      case type
      when :separate_only
        separate
      when :mixture_only
        mixture
      else
        "#{separate}, #{mixture}"
      end
    end

    # @return [String] the weights, scratch weights and components, inspected
    def debug_output
      <<-DEBUG_OUT
@weights=#{@weights.inspect}
@temp_weights=#{@temp_weights.inspect}
@models
#{@models.inspect}
      DEBUG_OUT
    end

    private

    # Sum of the current mixing weights.
    def weights_total
      @weights.inject(0) { |sum, v| sum + v }
    end

    # Whether +total+ equals 1.0 within WEIGHT_SUM_TOLERANCE.
    def unit_sum?(total)
      (total - 1.0).abs <= WEIGHT_SUM_TOLERANCE
    end
  end
end
|
@@ -0,0 +1,122 @@
|
|
1
|
+
module EMAlgorithm
  # Weighted mixture of component models.
  #
  # +@weights[i]+ is the mixing weight of +@models[i]+. +@temp_weights+ and
  # +@temp_weight_per_datum+ are scratch storage, kept parallel to +@models+,
  # that update_temp_weights! fills and update_parameters! consumes.
  #
  # NOTE(review): +DIGIT+ and the components' +pdf+ method are not defined in
  # this class; presumably Model (or the enclosing namespace) provides them.
  # +const+ is exposed through an accessor but never assigned in this class.
  class Mixture < Model
    # Largest deviation from 1.0 tolerated for the sum of the weights.
    # Floating-point sums are rarely exact (ten weights of 0.1 add up to
    # 0.9999999999999999), so an exact comparison rejects valid input.
    WEIGHT_SUM_TOLERANCE = 1.0e-9

    attr_accessor :models, :weights, :const

    # @param options [Hash] may override +:models+ (Array of components) and
    #   +:weights+ (Array of Numeric); defaults are two Gaussians weighted 0.5
    # @raise [ArgumentError] if the weights do not sum to 1.0
    def initialize(options)
      opts = {
        :models => [Gaussian.new(0.0, 9.0), Gaussian.new(10.0, 9.0)],
        :weights => [0.5, 0.5]
      }.merge(options)
      @models = opts[:models]
      @weights = opts[:weights]
      argument_error unless proper_weights?
      @temp_weights = Array.new(@models.size)
      @temp_weight_per_datum = Array.new(@models.size) { [] }
    end

    # @raise [ArgumentError] always
    def argument_error
      raise ArgumentError, "The summation of @weights must be equal to 1.0."
    end

    # Appends a component and its weight.
    #
    # The weights are validated before anything is mutated, so a rejected
    # call leaves the mixture untouched. The scratch arrays grow together
    # with +@models+ so the new component has a slot in update_temp_weights!.
    #
    # @param model [Model] component to append
    # @param weight [Numeric] its mixing weight
    # @raise [ArgumentError] if the weights would no longer sum to 1.0
    # @return [nil]
    def add(model = Gaussian.new(0.0, 9.0), weight = 0.0)
      argument_error unless unit_sum?(weights_total + weight)
      @models << model
      @weights << weight
      @temp_weights << nil
      @temp_weight_per_datum << []
      nil
    end

    # @return [Boolean] whether the weights sum to 1.0 within
    #   WEIGHT_SUM_TOLERANCE
    def proper_weights?
      unit_sum?(weights_total)
    end

    # Mixture density at +x+: the weighted sum of the component densities.
    def probability_density_function(x)
      pdf = 0.0
      @models.each_with_index do |model, mi|
        pdf += model.pdf(x) * @weights[mi]
      end
      pdf
    end

    # Empties the per-datum scratch weights of every component.
    def clear_temp_weight_per_datum!
      @temp_weight_per_datum.each { |w| w.clear }
    end

    # For every datum, the weighted sum of the component densities at it.
    #
    # @param data_array [Array] observations
    # @return [Array<Float>] one value per observation, in the same order
    def calculate_posterior_data_array(data_array)
      posterior_data_array = Array.new(data_array.size, 0.0)
      @models.each_with_index do |model, mi|
        data_array.each_with_index do |x, di|
          posterior_data_array[di] += @weights[mi] * model.pdf(x)
        end
      end
      posterior_data_array
    end

    # For each component and datum, appends
    # weight * pdf(x) / posterior_data_array[di] to the component's scratch
    # list (NaN is stored as 0.0), then stores the list's sum in
    # +@temp_weights+.
    #
    # Values are appended, not replaced: the scratch lists are only emptied
    # by clear_temp_weight_per_datum!.
    #
    # @param data_array [Array] observations
    # @param posterior_data_array [Array<Numeric>] per-datum divisors, as
    #   returned by calculate_posterior_data_array
    def update_temp_weights!(data_array, posterior_data_array)
      @models.each_with_index do |model, mi|
        # Create the slot on demand in case @models grew through its writer.
        per_datum = (@temp_weight_per_datum[mi] ||= [])
        data_array.each_with_index do |x, di|
          temp_weight_per_datum = @weights[mi] * model.pdf(x) / posterior_data_array[di]
          temp_weight_per_datum = 0.0 if temp_weight_per_datum.nan?
          per_datum << temp_weight_per_datum
        end
        @temp_weights[mi] = per_datum.inject(0.0) { |sum, w| sum + w }
      end
    end

    # Sets each mixing weight to its component's scratch weight divided by
    # the number of observations.
    def update_weights!(data_array)
      (0..(@models.size-1)).each do |mi|
        @weights[mi] = @temp_weights[mi] / data_array.size
      end
    end

    # Lets every component update its own parameters from the scratch
    # weights, then refreshes the mixing weights.
    def update_parameters!(data_array)
      @models.each_with_index do |model, mi|
        model.update_parameters!(data_array, @temp_weights[mi], @temp_weight_per_datum[mi])
      end
      update_weights!(data_array)
    end

    # Gnuplot expression(s) for the mixture.
    #
    # output types
    #   :full (default)  - the separate curves, then the mixture curve
    #   :separate_only   - one weighted, titled curve per component
    #   :mixture_only    - the sum of the weighted component curves
    #
    # @param type [Symbol]
    # @return [String]
    def to_gnuplot(type = :full)
      # output each model (currently assume Gaussian)
      separate = @models.each_with_index.map do |model, mi|
        "#{@weights[mi].round(DIGIT)} * #{model.to_gnuplot_with_title(@weights[mi])}"
      end.join(", ")
      # output mixture model (currently assume Gaussian Mixture model)
      mixture = @models.each_with_index.map do |model, mi|
        "#{@weights[mi].round(DIGIT)} * #{model.to_gnuplot}"
      end.join(" + ")
      case type
      when :separate_only
        separate
      when :mixture_only
        mixture
      else
        "#{separate}, #{mixture}"
      end
    end

    # @return [String] the weights, scratch weights and components, inspected
    def debug_output
      <<-DEBUG_OUT
@weights=#{@weights.inspect}
@temp_weights=#{@temp_weights.inspect}
@models
#{@models.inspect}
      DEBUG_OUT
    end

    private

    # Sum of the current mixing weights.
    def weights_total
      @weights.inject(0) { |sum, v| sum + v }
    end

    # Whether +total+ equals 1.0 within WEIGHT_SUM_TOLERANCE.
    def unit_sum?(total)
      (total - 1.0).abs <= WEIGHT_SUM_TOLERANCE
    end
  end
end
|