ruby-em_algorithm 0.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78) hide show
  1. data/Gemfile +6 -0
  2. data/Gemfile.lock +30 -0
  3. data/README.md +44 -0
  4. data/Rakefile +7 -0
  5. data/example/.ex1.rb.swp +0 -0
  6. data/example/.ex2.rb.swp +0 -0
  7. data/example/.ex3-tmp.rb.swp +0 -0
  8. data/example/.ex3.rb.swp +0 -0
  9. data/example/data/2dim-gmm-new.txt +1267 -0
  10. data/example/data/2dim-gmm-simple.txt +676 -0
  11. data/example/data/2dim-gmm-test.txt +6565 -0
  12. data/example/data/2dim-gmm-test2.txt +2782 -0
  13. data/example/data/2dim-gmm-test3.csv +1641 -0
  14. data/example/data/2dim-gmm-test3.txt +2782 -0
  15. data/example/data/2dim-gmm-test4.csv +868 -0
  16. data/example/data/2dim-gmm-test4.txt +4924 -0
  17. data/example/data/2dim-gmm-without_weight-small.txt +2401 -0
  18. data/example/data/2dim-gmm-without_weight.txt +18001 -0
  19. data/example/data/2dim-gmm.txt +1267 -0
  20. data/example/data/gmm-new.txt +10001 -0
  21. data/example/data/gmm-simple.txt +676 -0
  22. data/example/data/gmm.txt +10001 -0
  23. data/example/data/old-gmm.txt +10000 -0
  24. data/example/ex1.rb +20 -0
  25. data/example/ex1.rb~ +20 -0
  26. data/example/ex2.rb +33 -0
  27. data/example/ex2.rb~ +33 -0
  28. data/example/ex3-tmp.rb +23 -0
  29. data/example/ex3-tmp.rb~ +25 -0
  30. data/example/ex3.rb +43 -0
  31. data/example/ex3.rb~ +43 -0
  32. data/example/tools/.2dim.rb.swp +0 -0
  33. data/example/tools/2dim.rb +69 -0
  34. data/example/tools/2dim.rb~ +69 -0
  35. data/example/tools/boxmuller.rb +28 -0
  36. data/example/tools/boxmuller.rb~ +28 -0
  37. data/example/tools/conv_from_yaml.rb +8 -0
  38. data/example/tools/conv_from_yaml_to_csv.rb +8 -0
  39. data/example/tools/conv_to_yaml.rb +17 -0
  40. data/example/tools/ellipsoid.gnuplot +63 -0
  41. data/example/tools/ellipsoid.gnuplot~ +64 -0
  42. data/example/tools/histogram.rb +19 -0
  43. data/example/tools/histogram2d.rb +20 -0
  44. data/example/tools/histogram2d.rb~ +18 -0
  45. data/example/tools/kmeans.rb +34 -0
  46. data/example/tools/mean.rb +19 -0
  47. data/example/tools/table.data +4618 -0
  48. data/example/tools/tmp.txt +69632 -0
  49. data/example/tools/xmeans.R +608 -0
  50. data/example/tools/xmeans.rb +35 -0
  51. data/lib/em_algorithm/.base.rb.swp +0 -0
  52. data/lib/em_algorithm/base.rb +116 -0
  53. data/lib/em_algorithm/base.rb~ +116 -0
  54. data/lib/em_algorithm/convergence/.chi_square.rb.swp +0 -0
  55. data/lib/em_algorithm/convergence/.likelihood.rb.swp +0 -0
  56. data/lib/em_algorithm/convergence/check_method.rb +4 -0
  57. data/lib/em_algorithm/convergence/check_method.rb~ +0 -0
  58. data/lib/em_algorithm/convergence/chi_square.rb +40 -0
  59. data/lib/em_algorithm/convergence/chi_square.rb~ +40 -0
  60. data/lib/em_algorithm/convergence/likelihood.rb +35 -0
  61. data/lib/em_algorithm/convergence/likelihood.rb~ +35 -0
  62. data/lib/em_algorithm/models/.gaussian.rb.swp +0 -0
  63. data/lib/em_algorithm/models/.md_gaussian.rb.swp +0 -0
  64. data/lib/em_algorithm/models/.mixture.rb.swp +0 -0
  65. data/lib/em_algorithm/models/.model.rb.swp +0 -0
  66. data/lib/em_algorithm/models/gaussian.rb +47 -0
  67. data/lib/em_algorithm/models/gaussian.rb~ +47 -0
  68. data/lib/em_algorithm/models/md_gaussian.rb +67 -0
  69. data/lib/em_algorithm/models/md_gaussian.rb~ +67 -0
  70. data/lib/em_algorithm/models/mixture.rb +122 -0
  71. data/lib/em_algorithm/models/mixture.rb~ +122 -0
  72. data/lib/em_algorithm/models/model.rb +19 -0
  73. data/lib/em_algorithm/models/model.rb~ +19 -0
  74. data/lib/ruby-em_algorithm.rb +3 -0
  75. data/lib/ruby-em_algorithm/version.rb +3 -0
  76. data/ruby-em_algorithm.gemspec +21 -0
  77. data/spec/spec_helper.rb +9 -0
  78. metadata +178 -0
@@ -0,0 +1,47 @@
1
module EMAlgorithm
  # One-dimensional Gaussian (normal) component described by its mean +mu+
  # and standard deviation +sigma+.
  #
  # NOTE(review): +exp+, +sqrt+, +PI+ and +DIGIT+ are not defined in this
  # class; presumably Model (or the enclosing module) provides them -- confirm.
  class Gaussian < Model
    # +dim+ is exposed but never assigned anywhere in this class.
    attr_accessor :mu, :sigma, :dim

    # @param mu [Numeric] initial mean
    # @param sigma [Numeric] initial standard deviation
    def initialize(mu = 0.0, sigma = 1.0)
      @mu = mu
      @sigma = sigma
    end

    # Density of N(mu, sigma**2) evaluated at +x+.
    def probability_density_function(x)
      exp(-((x - @mu)**2.0) / (2.0 * @sigma**2)) / (sqrt(2.0 * PI) * @sigma)
    end

    # Density at an observation scaled by that observation's own weight.
    #
    # @param x_with_weight [#[]] element 0 is the value, element 1 its weight
    def probability_density_function_with_observation_weight(x_with_weight)
      x = x_with_weight[0]
      observation_weight = x_with_weight[1]
      observation_weight * probability_density_function(x)
    end

    # Replaces +mu+ with the weighted mean of +data_array+.
    #
    # When +temp_weight+ is zero the division below would be 0.0 / 0.0 and
    # store NaN in +mu+, after which every density evaluated from this
    # component is NaN as well. In that case the current mean is kept.
    #
    # @param data_array [#size, #[]] observations
    # @param temp_weight [Numeric] sum of +temp_weight_per_datum+
    # @param temp_weight_per_datum [#[]] per-observation weights, same indexing
    #   as +data_array+
    # @return [Numeric] the (possibly unchanged) mean
    def update_average!(data_array, temp_weight, temp_weight_per_datum)
      return @mu if temp_weight.zero?

      weighted_sum = (0...data_array.size).inject(0.0) do |sum, di|
        sum + temp_weight_per_datum[di] * data_array[di]
      end
      @mu = weighted_sum / temp_weight
    end

    # Replaces +sigma+ with the weighted standard deviation of +data_array+
    # around the current +mu+. Skipped when +temp_weight+ is zero, for the
    # same reason as in #update_average!.
    #
    # @return [Numeric] the (possibly unchanged) standard deviation
    def update_sigma!(data_array, temp_weight, temp_weight_per_datum)
      return @sigma if temp_weight.zero?

      weighted_squares = (0...data_array.size).inject(0.0) do |sum, di|
        sum + temp_weight_per_datum[di] * (data_array[di] - @mu)**2
      end
      @sigma = sqrt(weighted_squares / temp_weight)
    end

    # Updates the mean first, then the deviation, so the deviation is measured
    # around the new mean.
    def update_parameters!(data_array, temp_weight, temp_weight_per_datum)
      update_average!(data_array, temp_weight, temp_weight_per_datum)
      update_sigma!(data_array, temp_weight, temp_weight_per_datum)
    end

    # The density as a gnuplot expression in +x+, with every parameter passed
    # through +round(DIGIT)+.
    def to_gnuplot
      "exp(-((x-(#{@mu.round(DIGIT)}))**2.0)/(2.0*#{(@sigma**2).round(DIGIT)}))/(sqrt(2.0*pi)*#{@sigma.round(DIGIT)})"
    end

    # #to_gnuplot followed by gnuplot styling and a title that shows the
    # mixture +weight+ and the parameters of this component.
    def to_gnuplot_with_title(weight)
      to_gnuplot + " w l axis x1y2 lw 3 title '#{weight.round(DIGIT)}*N(#{@mu.round(DIGIT)},#{@sigma.round(DIGIT)})'"
    end
  end
end
@@ -0,0 +1,47 @@
1
module EMAlgorithm
  # One-dimensional Gaussian (normal) component described by its mean +mu+
  # and standard deviation +sigma+.
  #
  # NOTE(review): +exp+, +sqrt+, +PI+ and +DIGIT+ are not defined in this
  # class; presumably Model (or the enclosing module) provides them -- confirm.
  class Gaussian < Model
    # +dim+ is exposed but never assigned anywhere in this class.
    attr_accessor :mu, :sigma, :dim

    # @param mu [Numeric] initial mean
    # @param sigma [Numeric] initial standard deviation
    def initialize(mu = 0.0, sigma = 1.0)
      @mu = mu
      @sigma = sigma
    end

    # Density of N(mu, sigma**2) evaluated at +x+.
    def probability_density_function(x)
      exp(-((x - @mu)**2.0) / (2.0 * @sigma**2)) / (sqrt(2.0 * PI) * @sigma)
    end

    # Density at an observation scaled by that observation's own weight.
    #
    # @param x_with_weight [#[]] element 0 is the value, element 1 its weight
    def probability_density_function_with_observation_weight(x_with_weight)
      x = x_with_weight[0]
      observation_weight = x_with_weight[1]
      observation_weight * probability_density_function(x)
    end

    # Replaces +mu+ with the weighted mean of +data_array+.
    #
    # When +temp_weight+ is zero the division below would be 0.0 / 0.0 and
    # store NaN in +mu+, after which every density evaluated from this
    # component is NaN as well. In that case the current mean is kept.
    #
    # @param data_array [#size, #[]] observations
    # @param temp_weight [Numeric] sum of +temp_weight_per_datum+
    # @param temp_weight_per_datum [#[]] per-observation weights, same indexing
    #   as +data_array+
    # @return [Numeric] the (possibly unchanged) mean
    def update_average!(data_array, temp_weight, temp_weight_per_datum)
      return @mu if temp_weight.zero?

      weighted_sum = (0...data_array.size).inject(0.0) do |sum, di|
        sum + temp_weight_per_datum[di] * data_array[di]
      end
      @mu = weighted_sum / temp_weight
    end

    # Replaces +sigma+ with the weighted standard deviation of +data_array+
    # around the current +mu+. Skipped when +temp_weight+ is zero, for the
    # same reason as in #update_average!.
    #
    # @return [Numeric] the (possibly unchanged) standard deviation
    def update_sigma!(data_array, temp_weight, temp_weight_per_datum)
      return @sigma if temp_weight.zero?

      weighted_squares = (0...data_array.size).inject(0.0) do |sum, di|
        sum + temp_weight_per_datum[di] * (data_array[di] - @mu)**2
      end
      @sigma = sqrt(weighted_squares / temp_weight)
    end

    # Updates the mean first, then the deviation, so the deviation is measured
    # around the new mean.
    def update_parameters!(data_array, temp_weight, temp_weight_per_datum)
      update_average!(data_array, temp_weight, temp_weight_per_datum)
      update_sigma!(data_array, temp_weight, temp_weight_per_datum)
    end

    # The density as a gnuplot expression in +x+, with every parameter passed
    # through +round(DIGIT)+.
    def to_gnuplot
      "exp(-((x-(#{@mu.round(DIGIT)}))**2.0)/(2.0*#{(@sigma**2).round(DIGIT)}))/(sqrt(2.0*pi)*#{@sigma.round(DIGIT)})"
    end

    # #to_gnuplot followed by gnuplot styling and a title that shows the
    # mixture +weight+ and the parameters of this component.
    def to_gnuplot_with_title(weight)
      to_gnuplot + " w l axis x1y2 lw 3 title '#{weight.round(DIGIT)}*N(#{@mu.round(DIGIT)},#{@sigma.round(DIGIT)})'"
    end
  end
end
@@ -0,0 +1,67 @@
1
module EMAlgorithm
  # Multi-dimensional Gaussian component with mean vector +mu+ and covariance
  # matrix +sigma2+, both held as GSL objects. The square root of the
  # determinant and the inverse of +sigma2+ are cached because the density
  # uses them on every evaluation.
  #
  # NOTE(review): +exp+, +sqrt+, +PI+ and +DIGIT+ are not defined in this
  # class; presumably Model (or the enclosing module) provides them -- confirm.
  class MdGaussian < Model
    attr_accessor :mu
    attr_reader :sigma2

    # @param mu [GSL::Vector] mean vector
    # @param sigma2 [GSL::Matrix] covariance matrix, mu.size x mu.size
    # @raise [ArgumentError] when either argument has the wrong class or the
    #   matrix dimensions do not match the vector
    def initialize(mu = GSL::Vector[0.0, 0.0], sigma2 = GSL::Matrix[[1.0, 0.0], [0.0, 1.0]])
      # Exact class match is kept on purpose: it is what the original check
      # enforced, and subclasses are therefore still rejected.
      raise ArgumentError, "mu should be GSL::Vector." unless mu.instance_of?(GSL::Vector)
      raise ArgumentError, "sigma2 should be GSL::Matrix." unless sigma2.instance_of?(GSL::Matrix)
      unless sigma2.size1 == mu.size && sigma2.size2 == mu.size
        raise ArgumentError, "The size of sigma2 matrix does not match with mu vector."
      end

      @mu = mu
      @sigma2 = sigma2
      refresh_covariance_cache!
    end

    # Assigns a new covariance matrix and recomputes the cached determinant
    # root and inverse, so the density never runs on values derived from the
    # previous matrix.
    def sigma2=(matrix)
      @sigma2 = matrix
      refresh_covariance_cache!
    end

    # Density at +x+, using the cached inverse and determinant root.
    # NOTE(review): assumes +x+ is a GSL::Vector with the same size as +mu+ --
    # confirm against callers.
    def probability_density_function(x)
      deviation = x - @mu
      exp(-(deviation * @sigma2_invert * deviation.trans) / 2.0) /
        ((sqrt(2.0 * PI)**@mu.size) * @sqrt_sigma2_det)
    end

    # Replaces +mu+ with the weighted mean of +data_array+. A zero
    # +temp_weight+ would divide by zero and fill +mu+ with non-finite values,
    # so the current mean is kept in that case.
    def update_average!(data_array, temp_weight, temp_weight_per_datum)
      return @mu if temp_weight.zero?

      zero = GSL::Vector.alloc(@mu.size).set_zero
      weighted_sum = (0...data_array.size).inject(zero) do |sum, di|
        sum + temp_weight_per_datum[di] * data_array[di]
      end
      @mu = weighted_sum / temp_weight
    end

    # Replaces +sigma2+ with the weighted sum of the products
    # (x - mu).trans * (x - mu) divided by +temp_weight+, then refreshes the
    # cached determinant root and inverse. Skipped when +temp_weight+ is zero.
    def update_sigma2!(data_array, temp_weight, temp_weight_per_datum)
      return @sigma2_invert if temp_weight.zero?

      weighted_products = (0...data_array.size).inject(0.0) do |sum, di|
        deviation = data_array[di] - @mu
        sum + temp_weight_per_datum[di] * deviation.trans * deviation
      end
      @sigma2 = (weighted_products / temp_weight)
      refresh_covariance_cache!
    end

    # Updates the mean first so the covariance is measured around the new mean.
    def update_parameters!(data_array, temp_weight, temp_weight_per_datum)
      update_average!(data_array, temp_weight, temp_weight_per_datum)
      update_sigma2!(data_array, temp_weight, temp_weight_per_datum)
    end

    # For a two-dimensional component, the density as a gnuplot expression in
    # +x+ and +y+; otherwise a plain "N(mu, sigma2)" description.
    #
    # With d = [x - mu_x, y - mu_y] and P the inverse of +sigma2+, the exponent
    # is  P[0,0]*d_x**2 + (P[0,1] + P[1,0])*d_x*d_y + P[1,1]*d_y**2.
    # The coefficients come from the inverse, the same matrix that
    # #probability_density_function uses, so the plot matches the density.
    def to_gnuplot
      return "N(#{@mu.to_a.inspect}, #{@sigma2.to_a.inspect})" unless @mu.size == 2

      mu_x = @mu[0].round(DIGIT)
      mu_y = @mu[1].round(DIGIT)
      cross = @sigma2_invert[0, 1] + @sigma2_invert[1, 0]
      # The cross term is left out when its coefficient is exactly zero.
      xy = "+(#{cross.round(DIGIT)})*(x-(#{mu_x}))*(y-(#{mu_y}))" if cross > 0 || cross < 0
      "exp(-((#{@sigma2_invert[0, 0].round(DIGIT)})*(x-(#{mu_x}))**2.0+(#{@sigma2_invert[1, 1].round(DIGIT)})*(y-(#{mu_y}))**2.0#{xy})/2.0)/((sqrt(2.0*pi))**#{@mu.size}*(#{@sqrt_sigma2_det.round(DIGIT)}))"
    end

    # #to_gnuplot plus gnuplot styling and a title showing the mixture
    # +weight+, the mean and the covariance. Components that are not
    # two-dimensional get the same plain description as #to_gnuplot.
    def to_gnuplot_with_title(weight)
      return to_gnuplot unless @mu.size == 2

      rounded_mu = @mu.map { |v| v.round(DIGIT) }.to_a.inspect
      rounded_sigma2 = @sigma2.map { |v| v.round(DIGIT) }.to_a.inspect
      to_gnuplot + " w l lw 3 title '#{weight.round(DIGIT)}*N(#{rounded_mu},#{rounded_sigma2})'"
    end

    private

    # Recomputes the values derived from +@sigma2+.
    def refresh_covariance_cache!
      @sqrt_sigma2_det = sqrt(@sigma2.det)
      @sigma2_invert = @sigma2.invert
    end
  end
end
@@ -0,0 +1,67 @@
1
module EMAlgorithm
  # Multi-dimensional Gaussian component with mean vector +mu+ and covariance
  # matrix +sigma2+, both held as GSL objects. The square root of the
  # determinant and the inverse of +sigma2+ are cached because the density
  # uses them on every evaluation.
  #
  # NOTE(review): +exp+, +sqrt+, +PI+ and +DIGIT+ are not defined in this
  # class; presumably Model (or the enclosing module) provides them -- confirm.
  class MdGaussian < Model
    attr_accessor :mu
    attr_reader :sigma2

    # @param mu [GSL::Vector] mean vector
    # @param sigma2 [GSL::Matrix] covariance matrix, mu.size x mu.size
    # @raise [ArgumentError] when either argument has the wrong class or the
    #   matrix dimensions do not match the vector
    def initialize(mu = GSL::Vector[0.0, 0.0], sigma2 = GSL::Matrix[[1.0, 0.0], [0.0, 1.0]])
      # Exact class match is kept on purpose: it is what the original check
      # enforced, and subclasses are therefore still rejected.
      raise ArgumentError, "mu should be GSL::Vector." unless mu.instance_of?(GSL::Vector)
      raise ArgumentError, "sigma2 should be GSL::Matrix." unless sigma2.instance_of?(GSL::Matrix)
      unless sigma2.size1 == mu.size && sigma2.size2 == mu.size
        raise ArgumentError, "The size of sigma2 matrix does not match with mu vector."
      end

      @mu = mu
      @sigma2 = sigma2
      refresh_covariance_cache!
    end

    # Assigns a new covariance matrix and recomputes the cached determinant
    # root and inverse, so the density never runs on values derived from the
    # previous matrix.
    def sigma2=(matrix)
      @sigma2 = matrix
      refresh_covariance_cache!
    end

    # Density at +x+, using the cached inverse and determinant root.
    # NOTE(review): assumes +x+ is a GSL::Vector with the same size as +mu+ --
    # confirm against callers.
    def probability_density_function(x)
      deviation = x - @mu
      exp(-(deviation * @sigma2_invert * deviation.trans) / 2.0) /
        ((sqrt(2.0 * PI)**@mu.size) * @sqrt_sigma2_det)
    end

    # Replaces +mu+ with the weighted mean of +data_array+. A zero
    # +temp_weight+ would divide by zero and fill +mu+ with non-finite values,
    # so the current mean is kept in that case.
    def update_average!(data_array, temp_weight, temp_weight_per_datum)
      return @mu if temp_weight.zero?

      zero = GSL::Vector.alloc(@mu.size).set_zero
      weighted_sum = (0...data_array.size).inject(zero) do |sum, di|
        sum + temp_weight_per_datum[di] * data_array[di]
      end
      @mu = weighted_sum / temp_weight
    end

    # Replaces +sigma2+ with the weighted sum of the products
    # (x - mu).trans * (x - mu) divided by +temp_weight+, then refreshes the
    # cached determinant root and inverse. Skipped when +temp_weight+ is zero.
    def update_sigma2!(data_array, temp_weight, temp_weight_per_datum)
      return @sigma2_invert if temp_weight.zero?

      weighted_products = (0...data_array.size).inject(0.0) do |sum, di|
        deviation = data_array[di] - @mu
        sum + temp_weight_per_datum[di] * deviation.trans * deviation
      end
      @sigma2 = (weighted_products / temp_weight)
      refresh_covariance_cache!
    end

    # Updates the mean first so the covariance is measured around the new mean.
    def update_parameters!(data_array, temp_weight, temp_weight_per_datum)
      update_average!(data_array, temp_weight, temp_weight_per_datum)
      update_sigma2!(data_array, temp_weight, temp_weight_per_datum)
    end

    # For a two-dimensional component, the density as a gnuplot expression in
    # +x+ and +y+; otherwise a plain "N(mu, sigma2)" description.
    #
    # With d = [x - mu_x, y - mu_y] and P the inverse of +sigma2+, the exponent
    # is  P[0,0]*d_x**2 + (P[0,1] + P[1,0])*d_x*d_y + P[1,1]*d_y**2.
    # The coefficients come from the inverse, the same matrix that
    # #probability_density_function uses, so the plot matches the density.
    def to_gnuplot
      return "N(#{@mu.to_a.inspect}, #{@sigma2.to_a.inspect})" unless @mu.size == 2

      mu_x = @mu[0].round(DIGIT)
      mu_y = @mu[1].round(DIGIT)
      cross = @sigma2_invert[0, 1] + @sigma2_invert[1, 0]
      # The cross term is left out when its coefficient is exactly zero.
      xy = "+(#{cross.round(DIGIT)})*(x-(#{mu_x}))*(y-(#{mu_y}))" if cross > 0 || cross < 0
      "exp(-((#{@sigma2_invert[0, 0].round(DIGIT)})*(x-(#{mu_x}))**2.0+(#{@sigma2_invert[1, 1].round(DIGIT)})*(y-(#{mu_y}))**2.0#{xy})/2.0)/((sqrt(2.0*pi))**#{@mu.size}*(#{@sqrt_sigma2_det.round(DIGIT)}))"
    end

    # #to_gnuplot plus gnuplot styling and a title showing the mixture
    # +weight+, the mean and the covariance. Components that are not
    # two-dimensional get the same plain description as #to_gnuplot.
    def to_gnuplot_with_title(weight)
      return to_gnuplot unless @mu.size == 2

      rounded_mu = @mu.map { |v| v.round(DIGIT) }.to_a.inspect
      rounded_sigma2 = @sigma2.map { |v| v.round(DIGIT) }.to_a.inspect
      to_gnuplot + " w l lw 3 title '#{weight.round(DIGIT)}*N(#{rounded_mu},#{rounded_sigma2})'"
    end

    private

    # Recomputes the values derived from +@sigma2+.
    def refresh_covariance_cache!
      @sqrt_sigma2_det = sqrt(@sigma2.det)
      @sigma2_invert = @sigma2.invert
    end
  end
end
@@ -0,0 +1,122 @@
1
module EMAlgorithm
  # Weighted mixture of component models. Holds the components, their mixing
  # weights, and the per-model / per-datum temporary weights that the update
  # methods below read and write.
  #
  # NOTE(review): components are called through +pdf+, +update_parameters!+,
  # +to_gnuplot+ and +to_gnuplot_with_title+; +pdf+ and +DIGIT+ are not defined
  # in this class and presumably come from Model -- confirm.
  class Mixture < Model
    # Largest accepted distance between the sum of the weights and 1.0.
    # Binary floating point cannot add decimal fractions exactly
    # (0.7 + 0.2 + 0.1 is 0.9999999999999999), so an exact comparison would
    # reject valid weights.
    WEIGHT_SUM_TOLERANCE = 1.0e-9

    attr_accessor :models, :weights

    # @param options [Hash] may contain :models (Array of components) and
    #   :weights (Array of Numeric); anything missing falls back to two
    #   Gaussians with equal weights
    # @raise [ArgumentError] when the weights do not sum to 1.0
    def initialize(options = {})
      opts = {
        :models => [Gaussian.new(0.0, 9.0), Gaussian.new(10.0, 9.0)],
        :weights => [0.5, 0.5]
      }.merge(options)
      @models = opts[:models]
      @weights = opts[:weights]
      argument_error unless proper_weights?
      @temp_weights = Array.new(@models.size)
      @temp_weight_per_datum = Array.new(@models.size) { [] }
    end

    # Raises the error used for every invalid-weights condition.
    def argument_error
      raise ArgumentError, "The summation of @weights must be equal to 1.0."
    end

    # Appends a component and its weight.
    #
    # The weights are validated before anything is stored, so a rejected call
    # leaves the mixture untouched. The temporary weight arrays grow together
    # with +@models+ so that #update_temp_weights! has a slot for the new
    # component.
    #
    # @raise [ArgumentError] when the weights would no longer sum to 1.0
    # @return [nil]
    def add(model = Gaussian.new(0.0, 9.0), weight = 0.0)
      argument_error unless proper_weights?(@weights + [weight])
      @models << model
      @weights << weight
      @temp_weights << nil
      @temp_weight_per_datum << []
      nil
    end

    # True when +weights+ (the current weights by default) sum to 1.0 within
    # WEIGHT_SUM_TOLERANCE.
    def proper_weights?(weights = @weights)
      total = weights.inject(0) { |sum, v| sum + v }
      (total - 1.0).abs <= WEIGHT_SUM_TOLERANCE
    end

    # Mixture density at +x+: the weighted sum of the component densities.
    def probability_density_function(x)
      @models.each_with_index.inject(0.0) do |pdf, (model, mi)|
        pdf + model.pdf(x) * @weights[mi]
      end
    end

    # Empties the per-datum temporary weights of every component.
    def clear_temp_weight_per_datum!
      @temp_weight_per_datum.each { |w| w.clear }
    end

    # For every datum, the mixture density at that datum. These values are the
    # denominators used by #update_temp_weights!.
    #
    # @return [Array<Float>] one entry per element of +data_array+
    def calculate_posterior_data_array(data_array)
      posterior_data_array = Array.new(data_array.size, 0.0)
      @models.each_with_index do |model, mi|
        data_array.each_with_index do |x, di|
          posterior_data_array[di] += @weights[mi] * model.pdf(x)
        end
      end
      posterior_data_array
    end

    # Appends, for every component and datum, the share
    # weight * pdf(x) / posterior (NaN is stored as 0.0) and stores the sum
    # per component in +@temp_weights+.
    #
    # The shares are appended to the existing per-datum arrays, while
    # #update_parameters! indexes them by datum position, so
    # #clear_temp_weight_per_datum! has to run between two passes.
    def update_temp_weights!(data_array, posterior_data_array)
      @models.each_with_index do |model, mi|
        shares = @temp_weight_per_datum[mi]
        data_array.each_with_index do |x, di|
          share = @weights[mi] * model.pdf(x) / posterior_data_array[di]
          share = 0.0 if share.nan?
          shares << share
        end
        @temp_weights[mi] = shares.inject(0.0) { |sum, w| sum + w }
      end
    end

    # Sets every mixing weight to its temporary weight divided by the number
    # of data.
    def update_weights!(data_array)
      (0..(@models.size - 1)).each do |mi|
        @weights[mi] = @temp_weights[mi] / data_array.size
      end
    end

    # Lets every component re-estimate its parameters from its temporary
    # weights, then refreshes the mixing weights.
    def update_parameters!(data_array)
      @models.each_with_index do |model, mi|
        model.update_parameters!(data_array, @temp_weights[mi], @temp_weight_per_datum[mi])
      end
      update_weights!(data_array)
    end

    # Builds gnuplot expressions for the mixture.
    #
    # output types
    #   :full (default)  the separate components, then the mixture
    #   :separate_only   one titled expression per component, joined by ", "
    #   :mixture_only    the weighted component expressions joined by " + "
    #
    # Both strings are built regardless of +type+.
    def to_gnuplot(type = :full)
      # output each model (currently assume Gaussian)
      separate = @models.each_with_index.map do |model, mi|
        "#{@weights[mi].round(DIGIT)} * #{model.to_gnuplot_with_title(@weights[mi])}"
      end.join(", ")
      # output mixture model (currently assume Gaussian Mixture model)
      mixture = @models.each_with_index.map do |model, mi|
        "#{@weights[mi].round(DIGIT)} * #{model.to_gnuplot}"
      end.join(" + ")

      case type
      when :separate_only then separate
      when :mixture_only then mixture
      else "#{separate}, #{mixture}"
      end
    end

    # Multi-line dump of the weights, temporary weights and components.
    # The heredoc body is written flush-left so that no source indentation
    # ends up in the returned text.
    def debug_output
      <<-DEBUG_OUT
@weights=#{@weights.inspect}
@temp_weights=#{@temp_weights.inspect}
@models
#{@models.inspect}
      DEBUG_OUT
    end
  end
end
@@ -0,0 +1,122 @@
1
module EMAlgorithm
  # Weighted mixture of component models. Holds the components, their mixing
  # weights, and the per-model / per-datum temporary weights that the update
  # methods below read and write.
  #
  # NOTE(review): components are called through +pdf+, +update_parameters!+,
  # +to_gnuplot+ and +to_gnuplot_with_title+; +pdf+ and +DIGIT+ are not defined
  # in this class and presumably come from Model -- confirm.
  class Mixture < Model
    # Largest accepted distance between the sum of the weights and 1.0.
    # Binary floating point cannot add decimal fractions exactly
    # (0.7 + 0.2 + 0.1 is 0.9999999999999999), so an exact comparison would
    # reject valid weights.
    WEIGHT_SUM_TOLERANCE = 1.0e-9

    # +const+ is exposed but never assigned anywhere in this class.
    attr_accessor :models, :weights, :const

    # @param options [Hash] may contain :models (Array of components) and
    #   :weights (Array of Numeric); anything missing falls back to two
    #   Gaussians with equal weights
    # @raise [ArgumentError] when the weights do not sum to 1.0
    def initialize(options = {})
      opts = {
        :models => [Gaussian.new(0.0, 9.0), Gaussian.new(10.0, 9.0)],
        :weights => [0.5, 0.5]
      }.merge(options)
      @models = opts[:models]
      @weights = opts[:weights]
      argument_error unless proper_weights?
      @temp_weights = Array.new(@models.size)
      @temp_weight_per_datum = Array.new(@models.size) { [] }
    end

    # Raises the error used for every invalid-weights condition.
    def argument_error
      raise ArgumentError, "The summation of @weights must be equal to 1.0."
    end

    # Appends a component and its weight.
    #
    # The weights are validated before anything is stored, so a rejected call
    # leaves the mixture untouched. The temporary weight arrays grow together
    # with +@models+ so that #update_temp_weights! has a slot for the new
    # component.
    #
    # @raise [ArgumentError] when the weights would no longer sum to 1.0
    # @return [nil]
    def add(model = Gaussian.new(0.0, 9.0), weight = 0.0)
      argument_error unless proper_weights?(@weights + [weight])
      @models << model
      @weights << weight
      @temp_weights << nil
      @temp_weight_per_datum << []
      nil
    end

    # True when +weights+ (the current weights by default) sum to 1.0 within
    # WEIGHT_SUM_TOLERANCE.
    def proper_weights?(weights = @weights)
      total = weights.inject(0) { |sum, v| sum + v }
      (total - 1.0).abs <= WEIGHT_SUM_TOLERANCE
    end

    # Mixture density at +x+: the weighted sum of the component densities.
    def probability_density_function(x)
      @models.each_with_index.inject(0.0) do |pdf, (model, mi)|
        pdf + model.pdf(x) * @weights[mi]
      end
    end

    # Empties the per-datum temporary weights of every component.
    def clear_temp_weight_per_datum!
      @temp_weight_per_datum.each { |w| w.clear }
    end

    # For every datum, the mixture density at that datum. These values are the
    # denominators used by #update_temp_weights!.
    #
    # @return [Array<Float>] one entry per element of +data_array+
    def calculate_posterior_data_array(data_array)
      posterior_data_array = Array.new(data_array.size, 0.0)
      @models.each_with_index do |model, mi|
        data_array.each_with_index do |x, di|
          posterior_data_array[di] += @weights[mi] * model.pdf(x)
        end
      end
      posterior_data_array
    end

    # Appends, for every component and datum, the share
    # weight * pdf(x) / posterior (NaN is stored as 0.0) and stores the sum
    # per component in +@temp_weights+.
    #
    # The shares are appended to the existing per-datum arrays, while
    # #update_parameters! indexes them by datum position, so
    # #clear_temp_weight_per_datum! has to run between two passes.
    def update_temp_weights!(data_array, posterior_data_array)
      @models.each_with_index do |model, mi|
        shares = @temp_weight_per_datum[mi]
        data_array.each_with_index do |x, di|
          share = @weights[mi] * model.pdf(x) / posterior_data_array[di]
          share = 0.0 if share.nan?
          shares << share
        end
        @temp_weights[mi] = shares.inject(0.0) { |sum, w| sum + w }
      end
    end

    # Sets every mixing weight to its temporary weight divided by the number
    # of data.
    def update_weights!(data_array)
      (0..(@models.size - 1)).each do |mi|
        @weights[mi] = @temp_weights[mi] / data_array.size
      end
    end

    # Lets every component re-estimate its parameters from its temporary
    # weights, then refreshes the mixing weights.
    def update_parameters!(data_array)
      @models.each_with_index do |model, mi|
        model.update_parameters!(data_array, @temp_weights[mi], @temp_weight_per_datum[mi])
      end
      update_weights!(data_array)
    end

    # Builds gnuplot expressions for the mixture.
    #
    # output types
    #   :full (default)  the separate components, then the mixture
    #   :separate_only   one titled expression per component, joined by ", "
    #   :mixture_only    the weighted component expressions joined by " + "
    #
    # Both strings are built regardless of +type+.
    def to_gnuplot(type = :full)
      # output each model (currently assume Gaussian)
      separate = @models.each_with_index.map do |model, mi|
        "#{@weights[mi].round(DIGIT)} * #{model.to_gnuplot_with_title(@weights[mi])}"
      end.join(", ")
      # output mixture model (currently assume Gaussian Mixture model)
      mixture = @models.each_with_index.map do |model, mi|
        "#{@weights[mi].round(DIGIT)} * #{model.to_gnuplot}"
      end.join(" + ")

      case type
      when :separate_only then separate
      when :mixture_only then mixture
      else "#{separate}, #{mixture}"
      end
    end

    # Multi-line dump of the weights, temporary weights and components.
    # The heredoc body is written flush-left so that no source indentation
    # ends up in the returned text.
    def debug_output
      <<-DEBUG_OUT
@weights=#{@weights.inspect}
@temp_weights=#{@temp_weights.inspect}
@models
#{@models.inspect}
      DEBUG_OUT
    end
  end
end