ruby-em_algorithm 0.0.2

Sign up to get free protection for your applications and to get access to all the features.
Files changed (78) hide show
  1. data/Gemfile +6 -0
  2. data/Gemfile.lock +30 -0
  3. data/README.md +44 -0
  4. data/Rakefile +7 -0
  5. data/example/.ex1.rb.swp +0 -0
  6. data/example/.ex2.rb.swp +0 -0
  7. data/example/.ex3-tmp.rb.swp +0 -0
  8. data/example/.ex3.rb.swp +0 -0
  9. data/example/data/2dim-gmm-new.txt +1267 -0
  10. data/example/data/2dim-gmm-simple.txt +676 -0
  11. data/example/data/2dim-gmm-test.txt +6565 -0
  12. data/example/data/2dim-gmm-test2.txt +2782 -0
  13. data/example/data/2dim-gmm-test3.csv +1641 -0
  14. data/example/data/2dim-gmm-test3.txt +2782 -0
  15. data/example/data/2dim-gmm-test4.csv +868 -0
  16. data/example/data/2dim-gmm-test4.txt +4924 -0
  17. data/example/data/2dim-gmm-without_weight-small.txt +2401 -0
  18. data/example/data/2dim-gmm-without_weight.txt +18001 -0
  19. data/example/data/2dim-gmm.txt +1267 -0
  20. data/example/data/gmm-new.txt +10001 -0
  21. data/example/data/gmm-simple.txt +676 -0
  22. data/example/data/gmm.txt +10001 -0
  23. data/example/data/old-gmm.txt +10000 -0
  24. data/example/ex1.rb +20 -0
  25. data/example/ex1.rb~ +20 -0
  26. data/example/ex2.rb +33 -0
  27. data/example/ex2.rb~ +33 -0
  28. data/example/ex3-tmp.rb +23 -0
  29. data/example/ex3-tmp.rb~ +25 -0
  30. data/example/ex3.rb +43 -0
  31. data/example/ex3.rb~ +43 -0
  32. data/example/tools/.2dim.rb.swp +0 -0
  33. data/example/tools/2dim.rb +69 -0
  34. data/example/tools/2dim.rb~ +69 -0
  35. data/example/tools/boxmuller.rb +28 -0
  36. data/example/tools/boxmuller.rb~ +28 -0
  37. data/example/tools/conv_from_yaml.rb +8 -0
  38. data/example/tools/conv_from_yaml_to_csv.rb +8 -0
  39. data/example/tools/conv_to_yaml.rb +17 -0
  40. data/example/tools/ellipsoid.gnuplot +63 -0
  41. data/example/tools/ellipsoid.gnuplot~ +64 -0
  42. data/example/tools/histogram.rb +19 -0
  43. data/example/tools/histogram2d.rb +20 -0
  44. data/example/tools/histogram2d.rb~ +18 -0
  45. data/example/tools/kmeans.rb +34 -0
  46. data/example/tools/mean.rb +19 -0
  47. data/example/tools/table.data +4618 -0
  48. data/example/tools/tmp.txt +69632 -0
  49. data/example/tools/xmeans.R +608 -0
  50. data/example/tools/xmeans.rb +35 -0
  51. data/lib/em_algorithm/.base.rb.swp +0 -0
  52. data/lib/em_algorithm/base.rb +116 -0
  53. data/lib/em_algorithm/base.rb~ +116 -0
  54. data/lib/em_algorithm/convergence/.chi_square.rb.swp +0 -0
  55. data/lib/em_algorithm/convergence/.likelihood.rb.swp +0 -0
  56. data/lib/em_algorithm/convergence/check_method.rb +4 -0
  57. data/lib/em_algorithm/convergence/check_method.rb~ +0 -0
  58. data/lib/em_algorithm/convergence/chi_square.rb +40 -0
  59. data/lib/em_algorithm/convergence/chi_square.rb~ +40 -0
  60. data/lib/em_algorithm/convergence/likelihood.rb +35 -0
  61. data/lib/em_algorithm/convergence/likelihood.rb~ +35 -0
  62. data/lib/em_algorithm/models/.gaussian.rb.swp +0 -0
  63. data/lib/em_algorithm/models/.md_gaussian.rb.swp +0 -0
  64. data/lib/em_algorithm/models/.mixture.rb.swp +0 -0
  65. data/lib/em_algorithm/models/.model.rb.swp +0 -0
  66. data/lib/em_algorithm/models/gaussian.rb +47 -0
  67. data/lib/em_algorithm/models/gaussian.rb~ +47 -0
  68. data/lib/em_algorithm/models/md_gaussian.rb +67 -0
  69. data/lib/em_algorithm/models/md_gaussian.rb~ +67 -0
  70. data/lib/em_algorithm/models/mixture.rb +122 -0
  71. data/lib/em_algorithm/models/mixture.rb~ +122 -0
  72. data/lib/em_algorithm/models/model.rb +19 -0
  73. data/lib/em_algorithm/models/model.rb~ +19 -0
  74. data/lib/ruby-em_algorithm.rb +3 -0
  75. data/lib/ruby-em_algorithm/version.rb +3 -0
  76. data/ruby-em_algorithm.gemspec +21 -0
  77. data/spec/spec_helper.rb +9 -0
  78. metadata +178 -0
@@ -0,0 +1,47 @@
module EMAlgorithm
  # One-dimensional normal distribution N(mu, sigma^2), usable as a mixture
  # component. The unqualified exp/sqrt/PI/DIGIT names are not defined here;
  # NOTE(review): presumably they come from Model (e.g. via Math) -- confirm.
  class Gaussian < Model
    # NOTE(review): dim is exposed but never assigned by this class.
    attr_accessor :mu, :sigma, :dim

    # @param mu [Numeric] mean
    # @param sigma [Numeric] standard deviation
    def initialize(mu = 0.0, sigma = 1.0)
      @sigma = sigma
      @mu = mu
    end

    # Normal density at x: exp(-(x - mu)^2 / (2 sigma^2)) / (sqrt(2 pi) sigma).
    def probability_density_function(x)
      deviation = x - @mu
      exponent = -(deviation**2.0) / (2.0 * @sigma**2)
      exp(exponent) / (sqrt(2.0 * PI) * @sigma)
    end

    # Density scaled by an observation weight.
    # @param x_with_weight pair indexable as [x, observation_weight]
    def probability_density_function_with_observation_weight(x_with_weight)
      x_with_weight[1] * probability_density_function(x_with_weight[0])
    end

    # M-step: weighted mean sum_i(w_i * x_i) / temp_weight.
    # temp_weight is expected to be the total of temp_weight_per_datum
    # (Mixture#update_temp_weights! computes it that way).
    def update_average!(data_array, temp_weight, temp_weight_per_datum)
      weighted_total = data_array.size.times.inject(0.0) do |acc, idx|
        acc + temp_weight_per_datum[idx] * data_array[idx]
      end
      @mu = weighted_total / temp_weight
    end

    # M-step: weighted standard deviation around the current @mu.
    def update_sigma!(data_array, temp_weight, temp_weight_per_datum)
      squared_spread = data_array.size.times.inject(0.0) do |acc, idx|
        acc + temp_weight_per_datum[idx] * (data_array[idx] - @mu) ** 2
      end
      @sigma = sqrt(squared_spread / temp_weight)
    end

    # Re-estimate both parameters; the mean must be updated first because
    # update_sigma! measures spread around the new @mu.
    def update_parameters!(data_array, temp_weight, temp_weight_per_datum)
      update_average!(data_array, temp_weight, temp_weight_per_datum)
      update_sigma!(data_array, temp_weight, temp_weight_per_datum)
    end

    # gnuplot expression (in x) for this density, rounded to DIGIT places.
    def to_gnuplot
      mean = @mu.round(DIGIT)
      variance = (@sigma ** 2).round(DIGIT)
      std_dev = @sigma.round(DIGIT)
      "exp(-((x-(#{mean}))**2.0)/(2.0*#{variance}))/(sqrt(2.0*pi)*#{std_dev})"
    end

    # gnuplot plot clause: the density drawn as a line on the secondary y axis,
    # titled "weight*N(mu,sigma)".
    def to_gnuplot_with_title(weight)
      label = "#{weight.round(DIGIT)}*N(#{@mu.round(DIGIT)},#{@sigma.round(DIGIT)})"
      "#{to_gnuplot} w l axis x1y2 lw 3 title '#{label}'"
    end
  end
end
@@ -0,0 +1,47 @@
module EMAlgorithm
  # One-dimensional normal distribution N(mu, sigma^2), usable as a mixture
  # component. The unqualified exp/sqrt/PI/DIGIT names are not defined here;
  # NOTE(review): presumably they come from Model (e.g. via Math) -- confirm.
  class Gaussian < Model
    # NOTE(review): dim is exposed but never assigned by this class.
    attr_accessor :mu, :sigma, :dim

    # @param mu [Numeric] mean
    # @param sigma [Numeric] standard deviation
    def initialize(mu = 0.0, sigma = 1.0)
      @sigma = sigma
      @mu = mu
    end

    # Normal density at x: exp(-(x - mu)^2 / (2 sigma^2)) / (sqrt(2 pi) sigma).
    def probability_density_function(x)
      deviation = x - @mu
      exponent = -(deviation**2.0) / (2.0 * @sigma**2)
      exp(exponent) / (sqrt(2.0 * PI) * @sigma)
    end

    # Density scaled by an observation weight.
    # @param x_with_weight pair indexable as [x, observation_weight]
    def probability_density_function_with_observation_weight(x_with_weight)
      x_with_weight[1] * probability_density_function(x_with_weight[0])
    end

    # M-step: weighted mean sum_i(w_i * x_i) / temp_weight.
    # temp_weight is expected to be the total of temp_weight_per_datum
    # (Mixture#update_temp_weights! computes it that way).
    def update_average!(data_array, temp_weight, temp_weight_per_datum)
      weighted_total = data_array.size.times.inject(0.0) do |acc, idx|
        acc + temp_weight_per_datum[idx] * data_array[idx]
      end
      @mu = weighted_total / temp_weight
    end

    # M-step: weighted standard deviation around the current @mu.
    def update_sigma!(data_array, temp_weight, temp_weight_per_datum)
      squared_spread = data_array.size.times.inject(0.0) do |acc, idx|
        acc + temp_weight_per_datum[idx] * (data_array[idx] - @mu) ** 2
      end
      @sigma = sqrt(squared_spread / temp_weight)
    end

    # Re-estimate both parameters; the mean must be updated first because
    # update_sigma! measures spread around the new @mu.
    def update_parameters!(data_array, temp_weight, temp_weight_per_datum)
      update_average!(data_array, temp_weight, temp_weight_per_datum)
      update_sigma!(data_array, temp_weight, temp_weight_per_datum)
    end

    # gnuplot expression (in x) for this density, rounded to DIGIT places.
    def to_gnuplot
      mean = @mu.round(DIGIT)
      variance = (@sigma ** 2).round(DIGIT)
      std_dev = @sigma.round(DIGIT)
      "exp(-((x-(#{mean}))**2.0)/(2.0*#{variance}))/(sqrt(2.0*pi)*#{std_dev})"
    end

    # gnuplot plot clause: the density drawn as a line on the secondary y axis,
    # titled "weight*N(mu,sigma)".
    def to_gnuplot_with_title(weight)
      label = "#{weight.round(DIGIT)}*N(#{@mu.round(DIGIT)},#{@sigma.round(DIGIT)})"
      "#{to_gnuplot} w l axis x1y2 lw 3 title '#{label}'"
    end
  end
end
@@ -0,0 +1,67 @@
module EMAlgorithm
  # Multivariate normal distribution N(mu, sigma2) over GSL row vectors, usable
  # as a mixture component. The unqualified exp/sqrt/PI/DIGIT names are not
  # defined here; NOTE(review): presumably they come from Model -- confirm.
  class MdGaussian < Model
    attr_accessor :mu
    attr_reader :sigma2

    # @param mu [GSL::Vector] mean (row vector)
    # @param sigma2 [GSL::Matrix] covariance matrix of size mu.size x mu.size
    # @raise [ArgumentError] if mu / sigma2 have the wrong class or sigma2's
    #   dimensions do not match mu
    def initialize(mu = GSL::Vector[0.0, 0.0], sigma2 = GSL::Matrix[[1.0, 0.0], [0.0, 1.0]])
      # check mu
      if mu.class != GSL::Vector
        raise ArgumentError, "mu should be GSL::Vector."
      end
      @mu = mu
      # check sigma2
      if sigma2.class != GSL::Matrix
        raise ArgumentError, "sigma2 should be GSL::Matrix."
      elsif sigma2.size1 != @mu.size || sigma2.size2 != @mu.size
        raise ArgumentError, "The size of sigma2 matrix does not match with mu vector."
      end
      self.sigma2 = sigma2
    end

    # Replace the covariance matrix and refresh the cached sqrt(det) and
    # inverse used by probability_density_function and to_gnuplot. A plain
    # attr_accessor writer would leave those caches describing the old matrix.
    def sigma2=(matrix)
      @sigma2 = matrix
      @sqrt_sigma2_det = sqrt(@sigma2.det)
      @sigma2_invert = @sigma2.invert
    end

    # Density at x:
    #   exp(-(x - mu) sigma2^-1 (x - mu)^T / 2) / (sqrt(2 pi)^d * sqrt(det sigma2))
    # x is expected to be a row GSL::Vector like mu.
    def probability_density_function(x)
      deviation = x - @mu
      exp(-(deviation * @sigma2_invert * deviation.trans)/2.0)/((sqrt(2.0*PI)**@mu.size)*@sqrt_sigma2_det)
    end

    # M-step: weighted mean sum_i(w_i * x_i) / temp_weight.
    def update_average!(data_array, temp_weight, temp_weight_per_datum)
      data_sum = (0..(data_array.size-1)).inject(GSL::Vector.alloc(@mu.size).set_zero) do |sum, di|
        sum + temp_weight_per_datum[di] * data_array[di]
      end
      @mu = data_sum / temp_weight
    end

    # M-step: weighted covariance around the current @mu, built from outer
    # products (x_i - mu)^T (x_i - mu).
    def update_sigma2!(data_array, temp_weight, temp_weight_per_datum)
      dim = @mu.size
      # Seed with a zero matrix (as update_average! seeds with a zero vector)
      # instead of Float 0.0, so the accumulator is a GSL::Matrix from the
      # start rather than relying on Float + GSL::Matrix coercion.
      zero = GSL::Matrix.alloc(dim, dim)
      zero.set_zero
      data_sum = (0..(data_array.size-1)).inject(zero) do |sum, di|
        deviation = data_array[di] - @mu
        sum + temp_weight_per_datum[di] * deviation.trans * deviation
      end
      self.sigma2 = data_sum / temp_weight
    end

    # Re-estimate mean first, then covariance around the new mean.
    def update_parameters!(data_array, temp_weight, temp_weight_per_datum)
      update_average!(data_array, temp_weight, temp_weight_per_datum)
      update_sigma2!(data_array, temp_weight, temp_weight_per_datum)
    end

    # gnuplot expression in (x, y) for a 2-D density; for other dimensions a
    # plain "N(mu, sigma2)" description string.
    def to_gnuplot
      if @mu.size == 2
        # The exponent is the quadratic form with the INVERSE covariance
        # (precision) P = sigma2^-1, matching probability_density_function:
        # [x - mu_x, y - mu_y] * [[p_x, p_xy], [p_yx, p_y]] * [x - mu_x, y - mu_y]^T
        # = p_x*(x - mu_x)**2 + (p_xy + p_yx)*(x - mu_x)*(y - mu_y) + p_y*(y - mu_y)**2
        precision = @sigma2_invert
        mu_x = @mu[0].round(DIGIT)
        mu_y = @mu[1].round(DIGIT)
        p_cross = precision[0, 1] + precision[1, 0]
        xy = "+(#{p_cross.round(DIGIT)})*(x-(#{mu_x}))*(y-(#{mu_y}))" if p_cross > 0 || p_cross < 0
        "exp(-((#{precision[0, 0].round(DIGIT)})*(x-(#{mu_x}))**2.0+(#{precision[1, 1].round(DIGIT)})*(y-(#{mu_y}))**2.0#{xy})/2.0)/((sqrt(2.0*pi))**#{@mu.size}*(#{@sqrt_sigma2_det.round(DIGIT)}))"
      else
        "N(#{@mu.to_a.inspect}, #{@sigma2.to_a.inspect})"
      end
    end

    # gnuplot plot clause titled "weight*N(mu,sigma2)" for 2-D; otherwise the
    # same description string as to_gnuplot.
    def to_gnuplot_with_title(weight)
      if @mu.size == 2
        to_gnuplot + " w l lw 3 title '#{weight.round(DIGIT)}*N(#{@mu.map { |m| m.round(DIGIT) }.to_a.inspect},#{@sigma2.map { |s| s.round(DIGIT) }.to_a.inspect})'"
      else
        "N(#{@mu.to_a.inspect}, #{@sigma2.to_a.inspect})"
      end
    end
  end
end
@@ -0,0 +1,67 @@
module EMAlgorithm
  # Multivariate normal distribution N(mu, sigma2) over GSL row vectors, usable
  # as a mixture component. The unqualified exp/sqrt/PI/DIGIT names are not
  # defined here; NOTE(review): presumably they come from Model -- confirm.
  class MdGaussian < Model
    attr_accessor :mu
    attr_reader :sigma2

    # @param mu [GSL::Vector] mean (row vector)
    # @param sigma2 [GSL::Matrix] covariance matrix of size mu.size x mu.size
    # @raise [ArgumentError] if mu / sigma2 have the wrong class or sigma2's
    #   dimensions do not match mu
    def initialize(mu = GSL::Vector[0.0, 0.0], sigma2 = GSL::Matrix[[1.0, 0.0], [0.0, 1.0]])
      # check mu
      if mu.class != GSL::Vector
        raise ArgumentError, "mu should be GSL::Vector."
      end
      @mu = mu
      # check sigma2
      if sigma2.class != GSL::Matrix
        raise ArgumentError, "sigma2 should be GSL::Matrix."
      elsif sigma2.size1 != @mu.size || sigma2.size2 != @mu.size
        raise ArgumentError, "The size of sigma2 matrix does not match with mu vector."
      end
      self.sigma2 = sigma2
    end

    # Replace the covariance matrix and refresh the cached sqrt(det) and
    # inverse used by probability_density_function and to_gnuplot. A plain
    # attr_accessor writer would leave those caches describing the old matrix.
    def sigma2=(matrix)
      @sigma2 = matrix
      @sqrt_sigma2_det = sqrt(@sigma2.det)
      @sigma2_invert = @sigma2.invert
    end

    # Density at x:
    #   exp(-(x - mu) sigma2^-1 (x - mu)^T / 2) / (sqrt(2 pi)^d * sqrt(det sigma2))
    # x is expected to be a row GSL::Vector like mu.
    def probability_density_function(x)
      deviation = x - @mu
      exp(-(deviation * @sigma2_invert * deviation.trans)/2.0)/((sqrt(2.0*PI)**@mu.size)*@sqrt_sigma2_det)
    end

    # M-step: weighted mean sum_i(w_i * x_i) / temp_weight.
    def update_average!(data_array, temp_weight, temp_weight_per_datum)
      data_sum = (0..(data_array.size-1)).inject(GSL::Vector.alloc(@mu.size).set_zero) do |sum, di|
        sum + temp_weight_per_datum[di] * data_array[di]
      end
      @mu = data_sum / temp_weight
    end

    # M-step: weighted covariance around the current @mu, built from outer
    # products (x_i - mu)^T (x_i - mu).
    def update_sigma2!(data_array, temp_weight, temp_weight_per_datum)
      dim = @mu.size
      # Seed with a zero matrix (as update_average! seeds with a zero vector)
      # instead of Float 0.0, so the accumulator is a GSL::Matrix from the
      # start rather than relying on Float + GSL::Matrix coercion.
      zero = GSL::Matrix.alloc(dim, dim)
      zero.set_zero
      data_sum = (0..(data_array.size-1)).inject(zero) do |sum, di|
        deviation = data_array[di] - @mu
        sum + temp_weight_per_datum[di] * deviation.trans * deviation
      end
      self.sigma2 = data_sum / temp_weight
    end

    # Re-estimate mean first, then covariance around the new mean.
    def update_parameters!(data_array, temp_weight, temp_weight_per_datum)
      update_average!(data_array, temp_weight, temp_weight_per_datum)
      update_sigma2!(data_array, temp_weight, temp_weight_per_datum)
    end

    # gnuplot expression in (x, y) for a 2-D density; for other dimensions a
    # plain "N(mu, sigma2)" description string.
    def to_gnuplot
      if @mu.size == 2
        # The exponent is the quadratic form with the INVERSE covariance
        # (precision) P = sigma2^-1, matching probability_density_function:
        # [x - mu_x, y - mu_y] * [[p_x, p_xy], [p_yx, p_y]] * [x - mu_x, y - mu_y]^T
        # = p_x*(x - mu_x)**2 + (p_xy + p_yx)*(x - mu_x)*(y - mu_y) + p_y*(y - mu_y)**2
        precision = @sigma2_invert
        mu_x = @mu[0].round(DIGIT)
        mu_y = @mu[1].round(DIGIT)
        p_cross = precision[0, 1] + precision[1, 0]
        xy = "+(#{p_cross.round(DIGIT)})*(x-(#{mu_x}))*(y-(#{mu_y}))" if p_cross > 0 || p_cross < 0
        "exp(-((#{precision[0, 0].round(DIGIT)})*(x-(#{mu_x}))**2.0+(#{precision[1, 1].round(DIGIT)})*(y-(#{mu_y}))**2.0#{xy})/2.0)/((sqrt(2.0*pi))**#{@mu.size}*(#{@sqrt_sigma2_det.round(DIGIT)}))"
      else
        "N(#{@mu.to_a.inspect}, #{@sigma2.to_a.inspect})"
      end
    end

    # gnuplot plot clause titled "weight*N(mu,sigma2)" for 2-D; otherwise the
    # same description string as to_gnuplot.
    def to_gnuplot_with_title(weight)
      if @mu.size == 2
        to_gnuplot + " w l lw 3 title '#{weight.round(DIGIT)}*N(#{@mu.map { |m| m.round(DIGIT) }.to_a.inspect},#{@sigma2.map { |s| s.round(DIGIT) }.to_a.inspect})'"
      else
        "N(#{@mu.to_a.inspect}, #{@sigma2.to_a.inspect})"
      end
    end
  end
end
@@ -0,0 +1,122 @@
module EMAlgorithm
  # Finite mixture model: sum_k weights[k] * models[k].pdf(x), with the
  # weights required to sum to 1.0. Also holds the per-component, per-datum
  # responsibilities produced by the E-step and consumed by the M-step.
  class Mixture < Model
    # Allowed absolute deviation of sum(weights) from 1.0. An exact float
    # comparison rejects valid inputs such as [0.1, 0.2, 0.7], which sums
    # to 0.9999999999999999.
    WEIGHT_TOLERANCE = 1.0e-9

    attr_accessor :models, :weights

    # @param options [Hash] :models (Array of component models) and :weights
    #   (Array of Numeric summing to 1.0). Both default to a two-component
    #   1-D Gaussian mixture, so the whole hash may be omitted.
    # @raise [ArgumentError] if the weights do not sum to 1.0 or their count
    #   differs from the number of models
    def initialize(options = {})
      opts = {
        :models => [Gaussian.new(0.0, 9.0), Gaussian.new(10.0, 9.0)],
        :weights => [0.5, 0.5]
      }.merge(options)
      # Copy the arrays: add and update_weights! mutate them in place, which
      # would otherwise leak into the caller's arrays (or into another mixture
      # built from the same arrays).
      @models = opts[:models].dup
      @weights = opts[:weights].dup
      argument_error unless proper_weights?
      if @models.size != @weights.size
        raise ArgumentError, "The number of @weights must match the number of @models."
      end
      @temp_weights = Array.new(@models.size)
      @temp_weight_per_datum = Array.new(@models.size) { [] }
    end

    # @raise [ArgumentError] always; the shared "weights must sum to 1" error
    def argument_error
      raise ArgumentError, "The summation of @weights must be equal to 1.0."
    end

    # Append a component. The resulting weights must still sum to 1.0, so this
    # typically adds a zero-weight component. Nothing is modified if the
    # check fails.
    # @return [Mixture] self
    # @raise [ArgumentError] if the weights would not sum to 1.0
    def add(model = Gaussian.new(0.0, 9.0), weight = 0.0)
      argument_error unless proper_weights?(@weights + [weight])
      @models << model
      @weights << weight
      # Grow the E-step buffers as well; otherwise update_temp_weights! would
      # hit `nil << value` for the new component.
      @temp_weights << nil
      @temp_weight_per_datum << []
      self
    end

    # @param weights [Array<Numeric>] weights to check (defaults to @weights)
    # @return [Boolean] whether the weights sum to 1.0 within WEIGHT_TOLERANCE
    def proper_weights?(weights = @weights)
      (weights.inject(0) { |sum, v| sum + v } - 1.0).abs <= WEIGHT_TOLERANCE
    end

    # Mixture density at x: sum_k weights[k] * models[k].pdf(x).
    def probability_density_function(x)
      pdf = 0.0
      @models.each_with_index do |model, mi|
        pdf += model.pdf(x) * @weights[mi]
      end
      pdf
    end

    # Empty every component's per-datum responsibility list.
    def clear_temp_weight_per_datum!
      @temp_weight_per_datum.each { |w| w.clear }
    end

    # For each datum, the mixture density sum_k weights[k] * pdf_k(x)
    # (the normalizer of the responsibilities).
    # @return [Array<Float>] one value per element of data_array
    def calculate_posterior_data_array(data_array)
      posterior_data_array = Array.new(data_array.size, 0.0)
      @models.each_with_index do |model, mi|
        data_array.each_with_index do |x, di|
          posterior_data_array[di] += @weights[mi] * model.pdf(x)
        end
      end
      posterior_data_array
    end

    # E-step: responsibility of component mi for datum di is
    # weights[mi] * pdf_mi(x_di) / posterior_data_array[di]. A 0/0 result
    # (NaN) is treated as 0. Values are APPENDED to @temp_weight_per_datum;
    # NOTE(review): presumably the EM driver calls clear_temp_weight_per_datum!
    # before each E-step -- verify in base.rb.
    def update_temp_weights!(data_array, posterior_data_array)
      @models.each_with_index do |model, mi|
        data_array.each_with_index do |x, di|
          temp_weight_per_datum = @weights[mi] * model.pdf(x) / posterior_data_array[di]
          temp_weight_per_datum = 0.0 if temp_weight_per_datum.nan?
          @temp_weight_per_datum[mi] << temp_weight_per_datum
        end
        @temp_weights[mi] = @temp_weight_per_datum[mi].inject(0.0) { |sum, w| sum + w }
      end
    end

    # M-step for the mixing weights: total responsibility / number of data.
    def update_weights!(data_array)
      @models.each_index do |mi|
        @weights[mi] = @temp_weights[mi] / data_array.size
      end
    end

    # M-step: re-estimate every component, then the mixing weights.
    def update_parameters!(data_array)
      @models.each_with_index do |model, mi|
        model.update_parameters!(data_array, @temp_weights[mi], @temp_weight_per_datum[mi])
      end
      update_weights!(data_array)
    end

    # gnuplot output.
    # @param type [Symbol] :full (default) -- each component followed by the
    #   whole mixture; :separate_only -- titled components only;
    #   :mixture_only -- the summed mixture expression only
    # @return [String]
    def to_gnuplot(type = :full)
      components = @models.zip(@weights)
      separate = components.map { |model, weight| "#{weight.round(DIGIT)} * #{model.to_gnuplot_with_title(weight)}" }.join(", ")
      mixture = components.map { |model, weight| "#{weight.round(DIGIT)} * #{model.to_gnuplot}" }.join(" + ")
      case type
      when :separate_only then separate
      when :mixture_only then mixture
      else "#{separate}, #{mixture}"
      end
    end

    # Multi-line dump of the weights, E-step totals and component models.
    def debug_output
      <<-DEBUG_OUT
@weights=#{@weights.inspect}
@temp_weights=#{@temp_weights.inspect}
@models
#{@models.inspect}
      DEBUG_OUT
    end
  end
end
@@ -0,0 +1,122 @@
module EMAlgorithm
  # Finite mixture model: sum_k weights[k] * models[k].pdf(x), with the
  # weights required to sum to 1.0. Also holds the per-component, per-datum
  # responsibilities produced by the E-step and consumed by the M-step.
  class Mixture < Model
    # Allowed absolute deviation of sum(weights) from 1.0. An exact float
    # comparison rejects valid inputs such as [0.1, 0.2, 0.7], which sums
    # to 0.9999999999999999.
    WEIGHT_TOLERANCE = 1.0e-9

    # NOTE(review): :const is exposed but not read or written by this class.
    attr_accessor :models, :weights, :const

    # @param options [Hash] :models (Array of component models) and :weights
    #   (Array of Numeric summing to 1.0). Both default to a two-component
    #   1-D Gaussian mixture, so the whole hash may be omitted.
    # @raise [ArgumentError] if the weights do not sum to 1.0 or their count
    #   differs from the number of models
    def initialize(options = {})
      opts = {
        :models => [Gaussian.new(0.0, 9.0), Gaussian.new(10.0, 9.0)],
        :weights => [0.5, 0.5]
      }.merge(options)
      # Copy the arrays: add and update_weights! mutate them in place, which
      # would otherwise leak into the caller's arrays (or into another mixture
      # built from the same arrays).
      @models = opts[:models].dup
      @weights = opts[:weights].dup
      argument_error unless proper_weights?
      if @models.size != @weights.size
        raise ArgumentError, "The number of @weights must match the number of @models."
      end
      @temp_weights = Array.new(@models.size)
      @temp_weight_per_datum = Array.new(@models.size) { [] }
    end

    # @raise [ArgumentError] always; the shared "weights must sum to 1" error
    def argument_error
      raise ArgumentError, "The summation of @weights must be equal to 1.0."
    end

    # Append a component. The resulting weights must still sum to 1.0, so this
    # typically adds a zero-weight component. Nothing is modified if the
    # check fails.
    # @return [Mixture] self
    # @raise [ArgumentError] if the weights would not sum to 1.0
    def add(model = Gaussian.new(0.0, 9.0), weight = 0.0)
      argument_error unless proper_weights?(@weights + [weight])
      @models << model
      @weights << weight
      # Grow the E-step buffers as well; otherwise update_temp_weights! would
      # hit `nil << value` for the new component.
      @temp_weights << nil
      @temp_weight_per_datum << []
      self
    end

    # @param weights [Array<Numeric>] weights to check (defaults to @weights)
    # @return [Boolean] whether the weights sum to 1.0 within WEIGHT_TOLERANCE
    def proper_weights?(weights = @weights)
      (weights.inject(0) { |sum, v| sum + v } - 1.0).abs <= WEIGHT_TOLERANCE
    end

    # Mixture density at x: sum_k weights[k] * models[k].pdf(x).
    def probability_density_function(x)
      pdf = 0.0
      @models.each_with_index do |model, mi|
        pdf += model.pdf(x) * @weights[mi]
      end
      pdf
    end

    # Empty every component's per-datum responsibility list.
    def clear_temp_weight_per_datum!
      @temp_weight_per_datum.each { |w| w.clear }
    end

    # For each datum, the mixture density sum_k weights[k] * pdf_k(x)
    # (the normalizer of the responsibilities).
    # @return [Array<Float>] one value per element of data_array
    def calculate_posterior_data_array(data_array)
      posterior_data_array = Array.new(data_array.size, 0.0)
      @models.each_with_index do |model, mi|
        data_array.each_with_index do |x, di|
          posterior_data_array[di] += @weights[mi] * model.pdf(x)
        end
      end
      posterior_data_array
    end

    # E-step: responsibility of component mi for datum di is
    # weights[mi] * pdf_mi(x_di) / posterior_data_array[di]. A 0/0 result
    # (NaN) is treated as 0. Values are APPENDED to @temp_weight_per_datum;
    # NOTE(review): presumably the EM driver calls clear_temp_weight_per_datum!
    # before each E-step -- verify in base.rb.
    def update_temp_weights!(data_array, posterior_data_array)
      @models.each_with_index do |model, mi|
        data_array.each_with_index do |x, di|
          temp_weight_per_datum = @weights[mi] * model.pdf(x) / posterior_data_array[di]
          temp_weight_per_datum = 0.0 if temp_weight_per_datum.nan?
          @temp_weight_per_datum[mi] << temp_weight_per_datum
        end
        @temp_weights[mi] = @temp_weight_per_datum[mi].inject(0.0) { |sum, w| sum + w }
      end
    end

    # M-step for the mixing weights: total responsibility / number of data.
    def update_weights!(data_array)
      @models.each_index do |mi|
        @weights[mi] = @temp_weights[mi] / data_array.size
      end
    end

    # M-step: re-estimate every component, then the mixing weights.
    def update_parameters!(data_array)
      @models.each_with_index do |model, mi|
        model.update_parameters!(data_array, @temp_weights[mi], @temp_weight_per_datum[mi])
      end
      update_weights!(data_array)
    end

    # gnuplot output.
    # @param type [Symbol] :full (default) -- each component followed by the
    #   whole mixture; :separate_only -- titled components only;
    #   :mixture_only -- the summed mixture expression only
    # @return [String]
    def to_gnuplot(type = :full)
      components = @models.zip(@weights)
      separate = components.map { |model, weight| "#{weight.round(DIGIT)} * #{model.to_gnuplot_with_title(weight)}" }.join(", ")
      mixture = components.map { |model, weight| "#{weight.round(DIGIT)} * #{model.to_gnuplot}" }.join(" + ")
      case type
      when :separate_only then separate
      when :mixture_only then mixture
      else "#{separate}, #{mixture}"
      end
    end

    # Multi-line dump of the weights, E-step totals and component models.
    def debug_output
      <<-DEBUG_OUT
@weights=#{@weights.inspect}
@temp_weights=#{@temp_weights.inspect}
@models
#{@models.inspect}
      DEBUG_OUT
    end
  end
end