erv 0.2.0 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/lib/erv/mixture_distribution.rb +31 -12
- data/lib/erv/version.rb +1 -1
- data/test/erv/mixture_distribution_test.rb +69 -19
- metadata +3 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA1:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 7dd791a7886ff932086acc0bfe0c30eb43dab514
|
4
|
+
data.tar.gz: b3b088d570cfb19d579d524f86fe35ddc731d879
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 11b2ad325261926571fe63d51c9f08227fa1785de31c57c7f60769721d9195b892965debd9278a42b1aca129e5c8a53f812a29ef900eab502cde48df16c41a6d
|
7
|
+
data.tar.gz: af27054d6de5b9e3d2285505c116957a96faaae155dcc7441d3beaf23ad582eeef9a98b4d2e0dc217c938daf867cf16af3825415fd7a85ed5ade682f8d4b77a4
|
@@ -14,13 +14,19 @@ module ERV
|
|
14
14
|
raise ArgumentError, "Please, provide at least 2 distributions!" unless confs.length >= 2
|
15
15
|
|
16
16
|
@mixture = []
|
17
|
-
|
17
|
+
weight_sum = 0.0
|
18
18
|
while dist_conf = confs.shift
|
19
19
|
# get weight ...
|
20
20
|
weight = dist_conf.delete(:weight).to_f
|
21
21
|
|
22
22
|
# ... and keep track of it
|
23
|
-
|
23
|
+
weight_sum += weight
|
24
|
+
|
25
|
+
# get amplitude
|
26
|
+
amplitude = dist_conf.fetch(:amplitude) { 1.0 }
|
27
|
+
# unless amplitude.is_a? Float
|
28
|
+
# raise ArgumentError, "Please, provide at least 2 distributions!" unless confs.length >= 2
|
29
|
+
# end
|
24
30
|
|
25
31
|
# get distribution name
|
26
32
|
dist_name = dist_conf.delete(:distribution).to_s
|
@@ -32,7 +38,12 @@ module ERV
|
|
32
38
|
distribution = ERV.const_get(klass_name).new(dist_conf)
|
33
39
|
|
34
40
|
# add distribution to mixture
|
35
|
-
@mixture << { weight: weight, distribution: distribution }
|
41
|
+
@mixture << { amplitude: amplitude, weight: weight, distribution: distribution }
|
42
|
+
end
|
43
|
+
|
44
|
+
# normalize weights
|
45
|
+
@mixture.each do |dist|
|
46
|
+
dist[:weight] /= weight_sum
|
36
47
|
end
|
37
48
|
|
38
49
|
seed = opts[:seed]
|
@@ -44,13 +55,13 @@ module ERV
|
|
44
55
|
|
45
56
|
# find index of distribution we are supposed to sample from
|
46
57
|
i = 0
|
47
|
-
while x >
|
48
|
-
x -=
|
58
|
+
while x > @mixture[i][:weight]
|
59
|
+
x -= @mixture[i][:weight]
|
49
60
|
i += 1
|
50
61
|
end
|
51
62
|
|
52
63
|
# return sample
|
53
|
-
@mixture[i][:distribution].sample
|
64
|
+
@mixture[i][:amplitude] * @mixture[i][:distribution].sample
|
54
65
|
end
|
55
66
|
|
56
67
|
def mean
|
@@ -64,16 +75,24 @@ module ERV
|
|
64
75
|
private
|
65
76
|
|
66
77
|
def calculate_mean
|
67
|
-
@mixture.inject(0.0) do |s,x|
|
68
|
-
s += (
|
78
|
+
@mixture.inject(0.0) do |s,x|
|
79
|
+
s += (# the following formula was taken from
|
80
|
+
# https://en.wikipedia.org/wiki/Mixture_Distribution#Moments
|
81
|
+
x[:weight] *
|
82
|
+
# remember: E[aX] = a E[X]
|
83
|
+
x[:amplitude] * x[:distribution].mean)
|
69
84
|
end
|
70
85
|
end
|
71
86
|
|
72
87
|
def calculate_variance
|
73
|
-
@mixture.inject(0.0) do |s,x|
|
74
|
-
s += (
|
75
|
-
|
76
|
-
|
88
|
+
@mixture.inject(0.0) do |s,x|
|
89
|
+
s += (# the following formula was taken from
|
90
|
+
# https://en.wikipedia.org/wiki/Mixture_Distribution#Moments
|
91
|
+
x[:weight] *
|
92
|
+
# remember: E[aX] = a E[X]
|
93
|
+
((x[:amplitude] * x[:distribution].mean - self.mean) ** 2 +
|
94
|
+
# remember: Var(aX) = a**2 Var(X)
|
95
|
+
x[:amplitude] ** 2 * x[:distribution].variance))
|
77
96
|
end
|
78
97
|
end
|
79
98
|
end
|
data/lib/erv/version.rb
CHANGED
@@ -9,14 +9,6 @@ describe ERV::MixtureDistribution do
|
|
9
9
|
end.must_raise ArgumentError
|
10
10
|
end
|
11
11
|
|
12
|
-
it 'should keep track of distribution weights (for normalization)' do
|
13
|
-
# create a mixture distribution with unnormalized weights
|
14
|
-
uw_md = ERV::MixtureDistribution.new([ { distribution: :exponential, rate: 1.0, weight: 100.0 },
|
15
|
-
{ distribution: :exponential, rate: 2.0, weight: 200.0 },
|
16
|
-
{ distribution: :exponential, rate: 3.0, weight: 300.0 } ])
|
17
|
-
uw_md.instance_variable_get("@weight_sum").must_equal 600.0
|
18
|
-
end
|
19
|
-
|
20
12
|
let :md do
|
21
13
|
ERV::MixtureDistribution.new([ { distribution: :exponential, rate: 1.0, weight: 0.3 },
|
22
14
|
{ distribution: :exponential, rate: 2.0, weight: 0.2 },
|
@@ -33,23 +25,81 @@ describe ERV::MixtureDistribution do
|
|
33
25
|
|
34
26
|
context 'moments' do
|
35
27
|
|
36
|
-
|
37
|
-
|
38
|
-
|
28
|
+
context 'with default amplitude' do
|
29
|
+
let :md_expected_mean do
|
30
|
+
0.3 * 1/1.0 + 0.2 * 1/2.0 + 0.5 * 1/3.0
|
31
|
+
end
|
32
|
+
|
33
|
+
let :md_expected_variance do
|
34
|
+
0.3 * ((1/1.0 - md_expected_mean) ** 2 + (1/1.0) ** 2) +
|
35
|
+
0.2 * ((1/2.0 - md_expected_mean) ** 2 + (1/2.0) ** 2) +
|
36
|
+
0.5 * ((1/3.0 - md_expected_mean) ** 2 + (1/3.0) ** 2)
|
37
|
+
end
|
39
38
|
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
39
|
+
it 'should correctly calculate the mean of the mixture' do
|
40
|
+
md.mean.must_equal md_expected_mean
|
41
|
+
end
|
42
|
+
|
43
|
+
it 'should correctly calculate the variance of the mixture' do
|
44
|
+
md.variance.must_equal md_expected_variance
|
45
|
+
end
|
44
46
|
end
|
45
47
|
|
46
|
-
|
47
|
-
|
48
|
+
context 'with different amplitudes' do
|
49
|
+
let :amd do
|
50
|
+
ERV::MixtureDistribution.new([ { distribution: :gaussian, amplitude: 3.0, mean: 1.0, sd: 0.1, weight: 0.3 },
|
51
|
+
{ distribution: :gaussian, amplitude: 5.0, mean: 2.0, sd: 0.2, weight: 0.2 },
|
52
|
+
{ distribution: :gaussian, amplitude: 7.0, mean: 3.0, sd: 0.3, weight: 0.5 } ])
|
53
|
+
end
|
54
|
+
|
55
|
+
let :amd_expected_mean do
|
56
|
+
0.3 * 3.0 * 1.0 +
|
57
|
+
0.2 * 5.0 * 2.0 +
|
58
|
+
0.5 * 7.0 * 3.0
|
59
|
+
end
|
60
|
+
|
61
|
+
let :amd_expected_variance do
|
62
|
+
0.3 * ((3.0 * 1.0 - amd_expected_mean) ** 2 + (3.0 * 0.1) ** 2) +
|
63
|
+
0.2 * ((5.0 * 2.0 - amd_expected_mean) ** 2 + (5.0 * 0.2) ** 2) +
|
64
|
+
0.5 * ((7.0 * 3.0 - amd_expected_mean) ** 2 + (7.0 * 0.3) ** 2)
|
65
|
+
end
|
66
|
+
|
67
|
+
it 'should correctly calculate the mean of the mixture' do
|
68
|
+
amd.mean.must_equal amd_expected_mean
|
69
|
+
end
|
70
|
+
|
71
|
+
it 'should correctly calculate the variance of the mixture' do
|
72
|
+
amd.variance.must_equal amd_expected_variance
|
73
|
+
end
|
48
74
|
end
|
49
75
|
|
50
|
-
|
51
|
-
|
76
|
+
context 'with unnormalized weights' do
|
77
|
+
let :uwmd do
|
78
|
+
ERV::MixtureDistribution.new([ { distribution: :exponential, rate: 1.0, weight: 300 },
|
79
|
+
{ distribution: :exponential, rate: 2.0, weight: 200 },
|
80
|
+
{ distribution: :exponential, rate: 3.0, weight: 500 } ])
|
81
|
+
end
|
82
|
+
|
83
|
+
let :uwmd_expected_mean do
|
84
|
+
0.3 * 1/1.0 + 0.2 * 1/2.0 + 0.5 * 1/3.0
|
85
|
+
end
|
86
|
+
|
87
|
+
let :uwmd_expected_variance do
|
88
|
+
0.3 * ((1/1.0 - uwmd_expected_mean) ** 2 + (1/1.0) ** 2) +
|
89
|
+
0.2 * ((1/2.0 - uwmd_expected_mean) ** 2 + (1/2.0) ** 2) +
|
90
|
+
0.5 * ((1/3.0 - uwmd_expected_mean) ** 2 + (1/3.0) ** 2)
|
91
|
+
end
|
92
|
+
|
93
|
+
it 'should correctly calculate the mean of the mixture' do
|
94
|
+
uwmd.mean.must_equal uwmd_expected_mean
|
95
|
+
end
|
96
|
+
|
97
|
+
it 'should correctly calculate the variance of the mixture' do
|
98
|
+
uwmd.variance.must_equal uwmd_expected_variance
|
99
|
+
end
|
100
|
+
|
52
101
|
end
|
53
102
|
|
54
103
|
end
|
104
|
+
|
55
105
|
end
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: erv
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.3.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Mauro Tortonesi
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2017-01
|
11
|
+
date: 2017-02-01 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: bundler
|
@@ -80,7 +80,7 @@ dependencies:
|
|
80
80
|
- - "~>"
|
81
81
|
- !ruby/object:Gem::Version
|
82
82
|
version: 0.0.3
|
83
|
-
description: erv-0.
|
83
|
+
description: erv-0.3.0
|
84
84
|
email:
|
85
85
|
- mauro.tortonesi@unife.it
|
86
86
|
executables: []
|