num4regana 0.0.5-java → 0.0.6-java
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/ext/num4regana/AbstractGLM.java +3 -6
- data/ext/num4regana/AbstractGLMM.java +116 -43
- data/ext/num4regana/LogitBayesRegAna.java +3 -2
- data/ext/num4regana/PoissonBayesRegAna.java +3 -2
- data/ext/num4regana/PoissonHierBayesRegAna.java +4 -8
- data/lib/num4glmmregana.rb +2 -2
- data/lib/num4hbmregana.rb +1 -0
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 07475023e86feba9a9a2013d572cdc433753da1cb570aa60fb3eb3662b09f3eb
|
4
|
+
data.tar.gz: 24a67fb47fe74071f39ae3d9d078071c1de5fce5b15afb74625354b34777b6be
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: '058141be9257a86bc72b745402956142d015041e8812886a83ce2c49ac64993bc9fb1c864e2c7a093d0e42ba69f4dc45305f1eabb4fcffc4c0bc32d73333afb8'
|
7
|
+
data.tar.gz: db0acabd519684c4db318051df5467f50b0879aabdff44c426ff2c6e96559f762d286d1f8f7642052ff2dffa75847dc6f315610cad6c1d8fe268e14a251e6d22
|
data/CHANGELOG.md
CHANGED
@@ -32,9 +32,7 @@ abstract class AbstractGLM {
|
|
32
32
|
for(int i = 0; i < yi.length; i++) {
|
33
33
|
xi[0] = 1.0;
|
34
34
|
System.arraycopy(xij[i], 0, xi, 1, xij[0].length);
|
35
|
-
|
36
|
-
double q = regression(b, xi);
|
37
|
-
double p = linkFunc(q);
|
35
|
+
double p = linkFunc(regression(b, xi));
|
38
36
|
|
39
37
|
for(int j = 0; j < xi.length; j++) {
|
40
38
|
ei[j] += (p - yi[i]) * xi[j];
|
@@ -52,12 +50,11 @@ abstract class AbstractGLM {
|
|
52
50
|
xi[0] = 1.0;
|
53
51
|
System.arraycopy(xij[i], 0, xi, 1, xij[0].length);
|
54
52
|
|
55
|
-
double
|
56
|
-
double p = linkFunc(q);
|
53
|
+
double p = linkFunc(regression(b, xi));
|
57
54
|
|
58
55
|
l += Math.log(p);
|
59
56
|
}
|
60
57
|
return l;
|
61
|
-
}
|
58
|
+
}
|
62
59
|
}
|
63
60
|
|
@@ -10,28 +10,17 @@ abstract class AbstractGLMM {
|
|
10
10
|
// (メトロポリス法,ギブスサンプリング)
|
11
11
|
protected double[] mcmcGS(double[] yi, double[] b, double[][] xij) {
|
12
12
|
BetaDistribution beDist = new BetaDistribution(50, 50);
|
13
|
-
BetaDistribution beDist2 = new BetaDistribution(1, 1); // 確率用
|
14
13
|
double[] newB = new double[b.length];
|
15
|
-
double oldL = 0.0;
|
16
|
-
double newL = 0.0;
|
17
14
|
|
18
15
|
for(int i = 0; i < b.length; i++) {
|
19
16
|
newB = Arrays.copyOf(b, b.length);
|
20
|
-
oldL = calcLx(b,xij);
|
21
17
|
newB[i] = beDist.sample();
|
22
|
-
newL = calcLx(newB,xij);
|
23
18
|
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
double r2 = beDist2.sample();
|
30
|
-
|
31
|
-
if (r2 < (1.0 - r)) {
|
32
|
-
b[i] = newB[i];
|
33
|
-
}
|
34
|
-
}
|
19
|
+
b[i] = mcmcSample(
|
20
|
+
calcLx(b,xij), // oldL
|
21
|
+
calcLx(newB,xij), // newL
|
22
|
+
new double[] {b[i], newB[i]} // bTbl: [0]=> oldB, [1]=> newB
|
23
|
+
);
|
35
24
|
}
|
36
25
|
return b;
|
37
26
|
}
|
@@ -47,16 +36,94 @@ abstract class AbstractGLMM {
|
|
47
36
|
// EMアルゴリズム
|
48
37
|
protected double[] mcmcEM(double[] yi, double[] b, double[][] xij) {
|
49
38
|
double[] newB = new double[b.length];
|
50
|
-
|
51
|
-
double[] bE = calcEStep(yi, b);
|
39
|
+
double[][] bE = calcEStep(yi, b, xij);
|
52
40
|
double[] bM = calcMStep(yi, bE, xij);
|
53
41
|
|
54
42
|
for(int i = 0; i < newB.length; i++) {
|
55
|
-
newB[i] =
|
43
|
+
newB[i] = bM[i];
|
56
44
|
}
|
57
45
|
return newB;
|
58
46
|
}
|
47
|
+
/*********************************/
|
48
|
+
/* interface define */
|
49
|
+
/*********************************/
|
50
|
+
/*********************************/
|
51
|
+
/* class define */
|
52
|
+
/*********************************/
|
53
|
+
private static class ArraysFillEx {
|
54
|
+
public static void fill(Object array, Object value) {
|
55
|
+
// 第一引数が配列か判定
|
56
|
+
Class<?> type = array.getClass();
|
57
|
+
if (!type.isArray()) {
|
58
|
+
throw new IllegalArgumentException("not array");
|
59
|
+
}
|
60
|
+
|
61
|
+
// クラスの型を判定
|
62
|
+
String arrayClassName = array.getClass().getSimpleName()
|
63
|
+
.replace("[]", "")
|
64
|
+
.toLowerCase();
|
65
|
+
String valueClassName = value.getClass().getSimpleName()
|
66
|
+
.toLowerCase()
|
67
|
+
.replace("character", "char")
|
68
|
+
.replace("integer", "int");
|
69
|
+
if (!arrayClassName.equals(valueClassName)) {
|
70
|
+
throw new IllegalArgumentException("does not matc");
|
71
|
+
}
|
72
|
+
|
73
|
+
// 処理
|
74
|
+
if (type.getComponentType().isArray()) {
|
75
|
+
for(Object o: (Object[])array) {
|
76
|
+
fill(o, value);
|
77
|
+
}
|
78
|
+
}
|
79
|
+
else if (array instanceof boolean[]) {
|
80
|
+
Arrays.fill((boolean[])array, (boolean)value);
|
81
|
+
}
|
82
|
+
else if (array instanceof char[]) {
|
83
|
+
Arrays.fill((char[])array, (char)value);
|
84
|
+
}
|
85
|
+
else if (array instanceof byte[]) {
|
86
|
+
Arrays.fill((byte[])array, (byte)value);
|
87
|
+
}
|
88
|
+
else if (array instanceof short[]) {
|
89
|
+
Arrays.fill((short[])array, (short)value);
|
90
|
+
}
|
91
|
+
else if (array instanceof int[]) {
|
92
|
+
Arrays.fill((int[])array, (int)value);
|
93
|
+
}
|
94
|
+
else if (array instanceof long[]) {
|
95
|
+
Arrays.fill((long[])array, (long)value);
|
96
|
+
}
|
97
|
+
else if (array instanceof float[]) {
|
98
|
+
Arrays.fill((float[])array, (float)value);
|
99
|
+
}
|
100
|
+
else if (array instanceof double[]) {
|
101
|
+
Arrays.fill((double[])array, (double)value);
|
102
|
+
}
|
103
|
+
else {
|
104
|
+
Arrays.fill((Object[])array, value);
|
105
|
+
}
|
106
|
+
}
|
107
|
+
}
|
59
108
|
/* ------------------------------------------------------------------ */
|
109
|
+
private double mcmcSample(double oldL, double newL, double[] bTbl) {
|
110
|
+
double r = newL / oldL;
|
111
|
+
BetaDistribution beDist2 = new BetaDistribution(1, 1); // 確率用
|
112
|
+
double b;
|
113
|
+
|
114
|
+
b = bTbl[0];
|
115
|
+
if (r > 1.0) {
|
116
|
+
b = bTbl[1];
|
117
|
+
}
|
118
|
+
else {
|
119
|
+
double r2 = beDist2.sample();
|
120
|
+
|
121
|
+
if (r2 < (1.0 - r)) {
|
122
|
+
b = bTbl[1];
|
123
|
+
}
|
124
|
+
}
|
125
|
+
return b;
|
126
|
+
}
|
60
127
|
// 尤度計算(パラメータ)
|
61
128
|
private double calcLx(double[] b, double[][] xij) {
|
62
129
|
double l = 1.0;
|
@@ -65,10 +132,11 @@ abstract class AbstractGLMM {
|
|
65
132
|
for(int i = 0; i < xij.length; i++) {
|
66
133
|
xi[0] = 1.0;
|
67
134
|
System.arraycopy(xij[i], 0, xi, 1, xij[0].length);
|
68
|
-
double q =
|
69
|
-
|
135
|
+
double q = linkFunc(
|
136
|
+
regression(b, xi, nDist.sample())
|
137
|
+
);
|
70
138
|
|
71
|
-
l *=
|
139
|
+
l *= q;
|
72
140
|
}
|
73
141
|
return l;
|
74
142
|
}
|
@@ -80,46 +148,51 @@ abstract class AbstractGLMM {
|
|
80
148
|
for(int i = 0; i < xij.length; i++) {
|
81
149
|
xi[0] = 1.0;
|
82
150
|
System.arraycopy(xij[i], 0, xi, 1, xij[0].length);
|
83
|
-
double q =
|
84
|
-
|
151
|
+
double q = linkFunc(
|
152
|
+
regression(b, xi, nDist.sample())
|
153
|
+
);
|
85
154
|
|
86
|
-
l += Math.log(
|
155
|
+
l += Math.log(q);
|
87
156
|
}
|
88
157
|
return l;
|
89
158
|
}
|
90
159
|
// E-Step
|
91
|
-
// (
|
92
|
-
private double[] calcEStep(double[] yi, double[] b) {
|
93
|
-
double[] bh = new double[b.length];
|
160
|
+
// (Expectation:自己エントロピー)
|
161
|
+
private double[][] calcEStep(double[] yi, double[] b, double[][] xij) {
|
162
|
+
double[][] bh = new double[yi.length][b.length];
|
163
|
+
double[] xi = new double[b.length];
|
94
164
|
|
95
|
-
|
96
|
-
for(int
|
97
|
-
|
98
|
-
|
165
|
+
ArraysFillEx.fill(bh, 0.0);
|
166
|
+
for(int i = 0; i < yi.length; i++) {
|
167
|
+
xi[0] = 1.0;
|
168
|
+
System.arraycopy(xij[i], 0, xi, 1, xij[i].length);
|
169
|
+
double p = yi[i];
|
170
|
+
double q = linkFunc(regression(b, xi, nDist.sample()));
|
99
171
|
|
100
|
-
|
172
|
+
for(int j = 0; j < b.length; j++) {
|
173
|
+
bh[i][j] =
|
174
|
+
Math.log(p * xi[j]) - q * (Math.log(q) - Math.log(p * xi[j]));
|
101
175
|
}
|
102
|
-
bh[j] *= -1;
|
103
176
|
}
|
104
177
|
return bh;
|
105
178
|
}
|
106
179
|
// M-Step
|
107
|
-
// (KLダイバージェンス)
|
108
|
-
private double[] calcMStep(double[] yi, double[]
|
109
|
-
double[] xi = new double[
|
110
|
-
double[] ei = new double[
|
180
|
+
// (Maximization:KLダイバージェンス)
|
181
|
+
private double[] calcMStep(double[] yi, double[][] q, double[][] xij) {
|
182
|
+
double[] xi = new double[1 + xij[0].length];
|
183
|
+
double[] ei = new double[1 + xij[0].length];
|
111
184
|
|
112
185
|
Arrays.fill(ei, 0.0);
|
113
|
-
for(int j = 0; j <
|
186
|
+
for(int j = 0; j < xi.length; j++) {
|
114
187
|
for(int i = 0; i < xij.length; i++) {
|
115
188
|
xi[0] = 1.0;
|
116
189
|
System.arraycopy(xij[i], 0, xi, 1, xij[0].length);
|
117
190
|
|
118
|
-
double
|
119
|
-
double p = linkFunc(regression(b, xi, 0));
|
191
|
+
double p = yi[i];
|
120
192
|
|
121
|
-
ei[j] += q * (Math.log(
|
122
|
-
}
|
193
|
+
ei[j] += q[i][j] * (Math.log(p * xi[j]) - Math.log(q[i][j]));
|
194
|
+
}
|
195
|
+
ei[j] = -1 * ei[j];
|
123
196
|
}
|
124
197
|
return ei;
|
125
198
|
}
|
@@ -45,14 +45,15 @@ public class LogitBayesRegAna extends AbstractGLMM {
|
|
45
45
|
}
|
46
46
|
return meanB;
|
47
47
|
}
|
48
|
-
// q = b0 + b1 * x0
|
48
|
+
// q = b0 + b1 * x0 + r
|
49
|
+
// (ランダム切片モデル)
|
49
50
|
double regression(double[] b, double[] xi, double r) {
|
50
51
|
double ret = 0.0;
|
51
52
|
|
52
53
|
for(int i = 0; i < xi.length; i++) {
|
53
54
|
ret += b[i] * xi[i];
|
54
55
|
}
|
55
|
-
return ret;
|
56
|
+
return ret + r;
|
56
57
|
}
|
57
58
|
// p = 1 / (1 + exp( -q))
|
58
59
|
double linkFunc(double q) {
|
@@ -33,14 +33,15 @@ public class PoissonBayesRegAna extends AbstractGLMM {
|
|
33
33
|
}
|
34
34
|
return b;
|
35
35
|
}
|
36
|
-
// q = b0 + b1 * x0
|
36
|
+
// q = b0 + b1 * x0 + r
|
37
|
+
// (ランダム切片モデル)
|
37
38
|
double regression(double[] b, double[] xi, double r) {
|
38
39
|
double ret = 0.0;
|
39
40
|
|
40
41
|
for(int i = 0; i < xi.length; i++) {
|
41
42
|
ret += b[i] * xi[i];
|
42
43
|
}
|
43
|
-
return ret;
|
44
|
+
return ret + r;
|
44
45
|
}
|
45
46
|
// p = exp(q)
|
46
47
|
double linkFunc(double q) {
|
@@ -12,28 +12,24 @@ public class PoissonHierBayesRegAna extends AbstractGLMM {
|
|
12
12
|
double[] b = initB(xij[0].length);
|
13
13
|
|
14
14
|
for (int i = 0; i < NUM; i++) {
|
15
|
-
b =
|
15
|
+
b = mcmcEM(yi, b, xij);
|
16
16
|
}
|
17
17
|
return new LineReg(b);
|
18
18
|
}
|
19
19
|
private double[] initB(int xsie) {
|
20
20
|
double[] b = new double[1 + xsie];
|
21
|
-
BetaDistribution beDist = new BetaDistribution(50, 50);
|
22
21
|
|
23
|
-
|
24
|
-
b[i] = beDist.sample();
|
25
|
-
}
|
22
|
+
Arrays.fill(b, 0.0);
|
26
23
|
return b;
|
27
24
|
}
|
28
|
-
// q = b0 + b1 * x0
|
29
|
-
// (ランダム切片モデル)
|
25
|
+
// q = b0 + b1 * x0
|
30
26
|
double regression(double[] b, double[] xi, double r) {
|
31
27
|
double ret = 0.0;
|
32
28
|
|
33
29
|
for(int i = 0; i < xi.length; i++) {
|
34
30
|
ret += b[i] * xi[i];
|
35
31
|
}
|
36
|
-
return ret
|
32
|
+
return ret;
|
37
33
|
}
|
38
34
|
// p = exp(q)
|
39
35
|
double linkFunc(double q) {
|
data/lib/num4glmmregana.rb
CHANGED
@@ -71,7 +71,7 @@ module Num4GLMMRegAnaLib
|
|
71
71
|
# @overload get_bic(regcoe, xij)
|
72
72
|
# @param [Hash] regcoe 回帰係数(intercept:定数項 slope:回帰係数)
|
73
73
|
# @param [Array] xij xの値(double[][])
|
74
|
-
# @return double
|
74
|
+
# @return double BIC値
|
75
75
|
# @example
|
76
76
|
# reg = {
|
77
77
|
# :intercept=> -6.2313, # 定数項
|
@@ -154,7 +154,7 @@ module Num4GLMMRegAnaLib
|
|
154
154
|
# @overload get_bic(regcoe, xij)
|
155
155
|
# @param [Hash] regcoe 回帰係数(intercept:定数項 slope:回帰係数)
|
156
156
|
# @param [Array] xij xの値(double[][])
|
157
|
-
# @return double
|
157
|
+
# @return double BIC値
|
158
158
|
# @example
|
159
159
|
# reg = {
|
160
160
|
# :intercept=>0.4341885635221602, # 定数項
|
data/lib/num4hbmregana.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: num4regana
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.0.
|
4
|
+
version: 0.0.6
|
5
5
|
platform: java
|
6
6
|
authors:
|
7
7
|
- siranovel
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2024-
|
11
|
+
date: 2024-10-04 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rake
|