mjai-manue 0.0.1
Sign up to get free protection for your applications and to get access to all the features.
- data/bin/mjai-manue +10 -0
- data/lib/mjai/manue/danger_estimator.rb +730 -0
- data/lib/mjai/manue/hora_points_estimate.rb +488 -0
- data/lib/mjai/manue/hora_probability_estimator.rb +328 -0
- data/lib/mjai/manue/mjai_manue_command.rb +32 -0
- data/lib/mjai/manue/player.rb +373 -0
- data/share/danger.all.tree +0 -0
- data/share/hora_prob.marshal +0 -0
- metadata +65 -0
@@ -0,0 +1,730 @@
|
|
1
|
+
# coding: utf-8
|
2
|
+
|
3
|
+
require "set"
|
4
|
+
require "optparse"
|
5
|
+
|
6
|
+
require "mjai/pai"
|
7
|
+
|
8
|
+
|
9
|
+
module Mjai
|
10
|
+
|
11
|
+
module Manue
|
12
|
+
|
13
|
+
|
14
|
+
class DangerEstimator
|
15
|
+
|
16
|
+
class Scene
|
17
|
+
|
18
|
+
# Inverse urasuji table: maps a tile number to the numbers whose discard
# makes a tile of that number an urasuji. Read by urasuji_of.
# Deep-frozen because it is shared, read-only lookup data.
URASUJI_INV_MAP = {
  1 => [5],
  2 => [1, 6],
  3 => [2, 7],
  4 => [3, 5, 8],
  5 => [1, 4, 6, 9],
  6 => [2, 5, 7],
  7 => [3, 8],
  8 => [4, 9],
  9 => [5],
}.each_value(&:freeze).freeze

# Inverse senkisuji table: maps a tile number to the numbers whose discard
# makes a tile of that number a senkisuji. Numbers 1, 2, 8 and 9 have no
# entry. Read by senkisuji_of.
SENKISUJI_INV_MAP = {
  3 => [1, 8],
  4 => [2, 9],
  5 => [3, 7],
  6 => [1, 8],
  7 => [2, 9],
}.each_value(&:freeze).freeze
|
37
|
+
|
38
|
+
# Ordered registry of feature names. The position of a name in this list
# is the bit index that feature_vector assigns to it. Feature files record
# this list and refuse to load when it differs, so the order of
# define_feature calls must stay stable.
@@feature_names = []

# Defines a predicate method +name+ (taking one pai) from +block+ and
# registers it as the next feature bit.
#
# Raises ArgumentError if +name+ is already registered. A duplicate would
# silently replace the earlier method while still occupying two bits, and
# get_feature_value only ever looks up the first one.
def self.define_feature(name, &block)
  if @@feature_names.include?(name)
    raise(ArgumentError, "feature %p is already defined" % [name])
  end
  define_method(name, &block)
  @@feature_names.push(name)
end

# Returns the registered feature names in bit order. This is the registry
# array itself; callers must not mutate it.
def self.feature_names
  return @@feature_names
end
|
48
|
+
|
49
|
+
# Captures everything needed to evaluate features for one discard decision.
#
# game              - game state; doras and players are read from it (the
#                     Archive being replayed during feature extraction).
# me                - the player about to discard.
# dapai             - the tile being discarded, or nil.
# reacher           - the player in riichi.
# prereach_sutehais - the reacher's discards up to the riichi; the last
#                     element is used as the riichi declaration tile.
def initialize(game, me, dapai, reacher, prereach_sutehais)

  @game = game
  @dapai = dapai
  @me = me
  @reacher = reacher
  @prereach_sutehais = prereach_sutehais

  # Own hand plus the tile being discarded (when given).
  hand = @dapai ? @me.tehais + [@dapai] : @me.tehais.dup()
  half = @prereach_sutehais.size / 2

  @anpai_set = to_pai_set(reacher.anpais)
  @prereach_sutehai_set = to_pai_set(@prereach_sutehais)
  @early_sutehai_set = to_pai_set(@prereach_sutehais[0...half])
  @late_sutehai_set = to_pai_set(@prereach_sutehais[half..-1])
  @dora_set = to_pai_set(@game.doras)
  @tehai_set = to_pai_set(hand)

  # Tiles visible to this player: doras, own hand, every discard and meld.
  seen = @game.doras + @me.tehais
  @game.players.each do |player|
    seen += player.ho
    seen += player.furos.map(&:pais).flatten()
  end
  @visible_set = to_pai_set(seen)

  # Distinct non-red tiles of the hand that are not already known safe.
  @candidates = hand.
    map { |pai| pai.remove_red() }.
    uniq().
    reject { |pai| @anpai_set.has_key?(pai) }

end

attr_reader(:candidates)
|
82
|
+
|
83
|
+
# Builds a multiset of +pais+: a Hash from the non-red pai to its number of
# occurrences. Missing keys read as 0.
def to_pai_set(pais)
  counts = Hash.new(0)
  pais.each { |pai| counts[pai.remove_red()] += 1 }
  counts
end
|
90
|
+
|
91
|
+
# Returns the feature bit vector of +pai+, which must already have its red
# flag removed. Bit i is set when the i-th registered feature holds; an
# Integer bit vector keeps criterion matching (match?) cheap.
def feature_vector(pai)
  flags = @@feature_names.map { |feature_name| __send__(feature_name, pai) }
  return DangerEstimator.bool_array_to_bit_vector(flags)
end
|
97
|
+
|
98
|
+
# True if +pai+ (red or not) is one of the reacher's known safe tiles.
def anpai?(pai)
  plain = pai.remove_red()
  @anpai_set.has_key?(plain)
end
|
101
|
+
|
102
|
+
# Feature predicates. Each define_feature call appends one bit to the
# feature vector, so the order below fixes the bit layout that saved
# feature files and decision trees rely on. Do not reorder.

# Honor tile.
define_feature("tsupai") { |pai| pai.type == "t" }

# Numbered tile whose every suji partner (n-3 / n+3) is a known safe tile.
define_feature("suji") do |pai|
  pai.type != "t" &&
    get_suji_numbers(pai).all? { |n| @anpai_set.has_key?(Pai.new(pai.type, n)) }
end

# Half suji or full suji: at least one suji partner is a known safe tile.
define_feature("weak_suji") { |pai| suji_of(pai, @anpai_set) }

# Suji created by the riichi declaration tile: that tile is one of pai's
# suji partners and appears exactly once before riichi, and every suji
# partner was discarded before riichi. Honors, 1 and 9 never qualify.
define_feature("reach_suji") do |pai|
  declared = @prereach_sutehais[-1].remove_red()
  if pai.type == "t" || declared.type != pai.type || [1, 9].include?(pai.number)
    false
  else
    partners = get_suji_numbers(pai)
    partners.all? { |n| @prereach_sutehai_set.include?(Pai.new(pai.type, n)) } &&
      partners.include?(declared.number) &&
      @prereach_sutehai_set[declared] == 1
  end
end

# Urasuji (Japanese Wikipedia, suji article, urasuji section):
# http://ja.wikipedia.org/wiki/%E7%AD%8B_(%E9%BA%BB%E9%9B%80)#.E8.A3.8F.E3.82.B9.E3.82.B8
define_feature("urasuji") { |pai| urasuji_of(pai, @prereach_sutehai_set) }

define_feature("early_urasuji") { |pai| urasuji_of(pai, @early_sutehai_set) }

define_feature("reach_urasuji") { |pai| urasuji_of(pai, to_pai_set([self.reach_pai])) }

# Aida-yonken (Japanese Wikipedia, suji article, aida-yonken section):
# http://ja.wikipedia.org/wiki/%E7%AD%8B_(%E9%BA%BB%E9%9B%80)#.E9.96.93.E5.9B.9B.E9.96.93
define_feature("aida4ken") do |pai|
  if pai.type == "t"
    false
  else
    discarded = lambda { |k| @prereach_sutehai_set.has_key?(Pai.new(pai.type, k)) }
    num = pai.number
    ((2..5).include?(num) && discarded.call(num - 1) && discarded.call(num + 4)) ||
      ((5..8).include?(num) && discarded.call(num - 4) && discarded.call(num + 1))
  end
end

# Matagisuji (Japanese Wikipedia, suji article, matagisuji section):
# http://ja.wikipedia.org/wiki/%E7%AD%8B_(%E9%BA%BB%E9%9B%80)#.E3.81.BE.E3.81.9F.E3.81.8E.E3.82.B9.E3.82.B8
define_feature("matagisuji") { |pai| matagisuji_of(pai, @prereach_sutehai_set) }

define_feature("late_matagisuji") { |pai| matagisuji_of(pai, @late_sutehai_set) }

# Senkisuji (Japanese Wikipedia, suji article, senkisuji section):
# http://ja.wikipedia.org/wiki/%E7%AD%8B_(%E9%BA%BB%E9%9B%80)#.E7.96.9D.E6.B0.97.E3.82.B9.E3.82.B8
define_feature("senkisuji") { |pai| senkisuji_of(pai, @prereach_sutehai_set) }

define_feature("early_senkisuji") { |pai| senkisuji_of(pai, @early_sutehai_set) }

# Some tile between pai and 5 (inclusive of 5) was discarded.
define_feature("outer_prereach_sutehai") { |pai| outer(pai, @prereach_sutehai_set) }

define_feature("outer_early_sutehai") { |pai| outer(pai, @early_sutehai_set) }
|
183
|
+
|
184
|
+
# Wall (kabe) features for 1-3 / 7-9: one of the two inner neighbors has at
# least 4 - n copies visible. Definition order fixes bit positions.
0.upto(3) do |n|
  define_feature("chances<=#{n}") { |pai| n_chance_or_less(pai, n) }
end

# At least i + 1 copies of the tile are counted as visible.
1.upto(3) do |i|
  define_feature("visible>=%d" % i) { |pai| visible_n_or_more(pai, i) }
end

# Numbered tile in the range i..(10 - i).
2.upto(5) do |i|
  define_feature("%d<=n<=%d" % [i, 10 - i]) { |pai| num_n_or_inner(pai, i) }
end

define_feature("dora") { |pai| @dora_set.has_key?(pai) }

define_feature("dora_suji") { |pai| suji_of(pai, @dora_set) }

define_feature("dora_matagi") { |pai| matagisuji_of(pai, @dora_set) }

# At least i copies in the discarder's hand (including the tile discarded).
2.upto(4) do |i|
  define_feature("in_tehais>=#{i}") { |pai| @tehai_set[pai] >= i }
end

# Some suji partner has at least i copies in the discarder's hand.
2.upto(4) do |i|
  define_feature("suji_in_tehais>=#{i}") do |pai|
    pai.type != "t" &&
      get_suji_numbers(pai).any? { |n| @tehai_set[Pai.new(pai.type, n)] >= i }
  end
end
|
229
|
+
|
230
|
+
# At least j of the tiles within distance i (pai itself included) were
# discarded before riichi. Loop order (i outer, j inner) fixes bit positions.
1.upto(2) do |i|
  1.upto(i * 2) do |j|
    define_feature("+-#{i}_in_prereach_sutehais>=#{j}") do |pai|
      n_or_more_of_neighbors_in_prereach_sutehais(pai, j, i)
    end
  end
end

# The tile i steps toward 5 was discarded before riichi.
1.upto(2) do |i|
  define_feature("#{i}_outer_prereach_sutehai") { |pai| n_outer_prereach_sutehai(pai, i) }
end

# The tile i steps away from 5 was discarded before riichi.
1.upto(2) do |i|
  define_feature("#{i}_inner_prereach_sutehai") { |pai| n_outer_prereach_sutehai(pai, -i) }
end

# At least i distinct numbers of pai's suit were discarded before riichi.
1.upto(8) do |i|
  define_feature("same_type_in_prereach>=#{i}") do |pai|
    if pai.type == "t"
      false
    else
      (1..9).count { |n| @prereach_sutehai_set.has_key?(Pai.new(pai.type, n)) } >= i
    end
  end
end
|
262
|
+
|
263
|
+
# Yakuhai for the reacher: at least one han from fanpai_fansu.
define_feature("fanpai") { |pai| fanpai_fansu(pai) >= 1 }

# Double yakuhai: round wind that is also the reacher's seat wind.
define_feature("ryenfonpai") { |pai| fanpai_fansu(pai) >= 2 }

# Dragon tile (honor with number >= 5).
define_feature("sangenpai") { |pai| pai.type == "t" && pai.number >= 5 }

# Wind tile (honor with number < 5).
define_feature("fonpai") { |pai| pai.type == "t" && pai.number < 5 }

define_feature("bakaze") { |pai| pai == @game.bakaze }

define_feature("jikaze") { |pai| pai == @reacher.jikaze }
|
286
|
+
|
287
|
+
# True when the tile +n+ steps from +pai+ toward 5 was among the pre-riichi
# discards. A negative +n+ looks outward instead, which the
# "_inner_prereach_sutehai" features use. Honors never match.
def n_outer_prereach_sutehai(pai, n)
  return false if pai.type == "t"
  num = pai.number
  return false if num >= 6 - n && num <= 4 + n
  shifted = num < 5 ? num + n : num - n
  @prereach_sutehai_set.include?(Pai.new(pai.type, shifted))
end
|
297
|
+
|
298
|
+
# True when at least +n+ tiles of pai's suit within +neighbor_distance+ of
# pai (pai itself included; numbers outside 1..9 never match) were
# discarded before riichi. Honors never match.
def n_or_more_of_neighbors_in_prereach_sutehais(pai, n, neighbor_distance)
  return false if pai.type == "t"
  window = (pai.number - neighbor_distance)..(pai.number + neighbor_distance)
  hits = window.count { |num| @prereach_sutehai_set.has_key?(Pai.new(pai.type, num)) }
  hits >= n
end
|
309
|
+
|
310
|
+
# True when some suji partner of +pai+ is a key of +target_pai_set+.
# Honors never match.
def suji_of(pai, target_pai_set)
  return false if pai.type == "t"
  get_suji_numbers(pai).any? { |num| target_pai_set.has_key?(Pai.new(pai.type, num)) }
end

# Suji partner numbers of +pai+: n - 3 and n + 3, whichever lie in 1..9.
def get_suji_numbers(pai)
  [-3, 3].map { |offset| pai.number + offset }.select { |num| num.between?(1, 9) }
end
|
321
|
+
|
322
|
+
# Wall check for 1-3 / 7-9: true when one of the two tiles on pai's inner
# side has at least 4 - n copies in @visible_set. Honors and 4-6 never
# match.
def n_chance_or_less(pai, n)
  return false if pai.type == "t" || (4..6).include?(pai.number)
  step = pai.number < 5 ? 1 : -1
  [1, 2].any? do |distance|
    wall = Pai.new(pai.type, pai.number + step * distance)
    @visible_set[wall] >= 4 - n
  end
end
|
332
|
+
|
333
|
+
# True for a numbered tile whose number lies in n..(10 - n).
def num_n_or_inner(pai, n)
  pai.type != "t" && pai.number.between?(n, 10 - n)
end

# True when at least n + 1 copies of +pai+ are in @visible_set.
# n does not include the tile itself, hence the extra copy.
def visible_n_or_more(pai, n)
  @visible_set[pai] > n
end
|
341
|
+
|
342
|
+
# True when +pai+ is an urasuji of some tile in +target_pai_set+, using the
# inverse table URASUJI_INV_MAP. Honors never match.
def urasuji_of(pai, target_pai_set)
  return false if pai.type == "t"
  urasuji_numbers = URASUJI_INV_MAP[pai.number]
  return urasuji_numbers.any?(){ |n| target_pai_set.has_key?(Pai.new(pai.type, n)) }
end

# True when +pai+ is a senkisuji of some tile in +target_pai_set+, using the
# inverse table SENKISUJI_INV_MAP. Honors and numbers absent from the table
# (1, 2, 8, 9) return false. Previously these numbers leaked nil instead of
# a boolean.
def senkisuji_of(pai, target_pai_set)
  return false if pai.type == "t"
  senkisuji_numbers = SENKISUJI_INV_MAP[pai.number]
  return false if !senkisuji_numbers
  return senkisuji_numbers.any?(){ |n| target_pai_set.has_key?(Pai.new(pai.type, n)) }
end
|
360
|
+
|
361
|
+
# True when some tile one or two steps from +pai+ (below it when the
# number is >= 4, above it when <= 6) is a key of +target_pai_set+.
# Honors never match.
def matagisuji_of(pai, target_pai_set)
  return false if pai.type == "t"
  num = pai.number
  candidates = []
  candidates.concat([num - 2, num - 1]) if num >= 4
  candidates.concat([num + 1, num + 2]) if num <= 6
  candidates.any? { |k| target_pai_set.has_key?(Pai.new(pai.type, k)) }
end
|
375
|
+
|
376
|
+
# True when some tile strictly between +pai+ and 5 (5 included) is a key of
# +target_pai_set+. Honors and 5 itself never match.
def outer(pai, target_pai_set)
  return false if pai.type == "t" || pai.number == 5
  between = pai.number < 5 ? ((pai.number + 1)..5) : (5..(pai.number - 1))
  between.any? { |num| target_pai_set.has_key?(Pai.new(pai.type, num)) }
end
|
384
|
+
|
385
|
+
# The riichi declaration tile: the last pre-riichi discard (nil if none).
def reach_pai
  @prereach_sutehais.last
end

# Han +pai+ would be worth as yakuhai for the reacher. Dragons count 1;
# winds count 1 for the round wind plus 1 for the reacher's seat wind.
def fanpai_fansu(pai)
  return 1 if pai.type == "t" && pai.number >= 5
  fansu = 0
  fansu += 1 if pai == @game.bakaze
  fansu += 1 if pai == @reacher.jikaze
  fansu
end
|
396
|
+
|
397
|
+
end
|
398
|
+
|
399
|
+
# One kyoku as stored in a features file: its recorded scenes (StoredScene).
# Arrays of these are written with Marshal.dump, so the member list is part
# of the file format.
StoredKyoku = Struct.new(:scenes)
# One recorded discard decision. candidates is an Array of
# [feature_vector, hit] pairs; hit is true when the candidate tile was one
# of the reacher's waits.
StoredScene = Struct.new(:candidates)
# Node of the danger decision tree. average_prob, conf_interval ([low, high])
# and num_samples summarize the kyokus matching the node's criterion.
# feature_name is nil for a leaf; otherwise positive/negative are the child
# nodes for that feature being true/false. Trees are loaded with
# Marshal.load, so the member names must not change.
DecisionNode = Struct.new(
    :average_prob, :conf_interval, :num_samples, :feature_name, :positive, :negative)
|
403
|
+
|
404
|
+
# A danger decision tree loaded from a Marshal dump of its root
# DecisionNode (presumably a tree produced by generate_decision_tree and
# saved by the caller). Internal nodes split on one feature; leaves hold the
# estimated probability that a tile is one of the reacher's waits.
class DecisionTree

  # path - file containing a Marshal dump of the root DecisionNode.
  #        Marshal.load can instantiate arbitrary objects, so only load
  #        trusted files.
  def initialize(path)
    # File.open, unlike Kernel#open, never treats a leading "|" in +path+
    # as a command to execute.
    @root = File.open(path, "rb"){ |f| Marshal.load(f) }
  end

  # Walks the tree using +pai+'s features in +scene+ and returns the leaf's
  # average_prob. Also prints the pai and its active features to stdout
  # for debugging.
  def estimate_prob(scene, pai)
    feature_vector = scene.feature_vector(pai.remove_red())
    p [pai, DangerEstimator.feature_vector_to_str(feature_vector)]
    node = @root
    while node.feature_name
      if DangerEstimator.get_feature_value(feature_vector, node.feature_name)
        node = node.positive
      else
        node = node.negative
      end
    end
    return node.average_prob
  end

end
|
425
|
+
|
426
|
+
# When truthy, feature extraction dumps every archive action and every
# candidate's features to stdout.
attr_accessor(:verbose)
# Minimum gap between child confidence intervals that
# generate_decision_tree requires before splitting a node.
attr_accessor(:min_gap)

# Starts with min_gap 0.0; verbose is left unset (nil).
def initialize
  @min_gap = 0.0
end
|
432
|
+
|
433
|
+
# Extracts features from every archive in +input_paths+ and writes them to
# +output_path+ as a stream of Marshal records:
# - first, a metadata Hash holding the feature names;
# - then Arrays of StoredKyoku, written every 100 input files and once more
#   at the end.
def extract_features_from_files(input_paths, output_path)
  $stderr.puts("%d files." % input_paths.size)
  # File.open rather than Kernel#open, so a path beginning with "|" is
  # never run as a command.
  File.open(output_path, "wb") do |f|
    meta_data = {
      :feature_names => Scene.feature_names,
    }
    Marshal.dump(meta_data, f)
    @stored_kyokus = []
    input_paths.enum_for(:each_with_progress).each_with_index() do |path, i|
      # Flush accumulated kyokus in batches (presumably to bound memory use).
      if i % 100 == 0 && i > 0
        Marshal.dump(@stored_kyokus, f)
        @stored_kyokus.clear()
      end
      extract_features_from_file(path)
    end
    Marshal.dump(@stored_kyokus, f)
  end
end
|
451
|
+
|
452
|
+
# Replays one archive file and appends one StoredKyoku per finished,
# non-skipped kyoku to @stored_kyokus.
#
# After the first riichi of a kyoku is accepted, each discard by a player
# who is not in riichi becomes a StoredScene. For every candidate tile in
# that player's hand, the scene records the tile's feature vector and
# whether it was in the reacher's wait.
#
# A kyoku is skipped entirely when a second riichi is accepted, or when the
# riichi player is named "ASAPIN" or "(≧▽≦)" (the reason for excluding those
# players is not visible here).
#
# On any error the input path is printed to $stderr and the exception is
# re-raised.
def extract_features_from_file(input_path)
  begin
    stored_kyoku = nil
    reacher = nil
    # Tiles the reacher waits on, computed when the riichi is accepted.
    waited = nil
    # Snapshot of the reacher's discards when the riichi is accepted.
    prereach_sutehais = nil
    # Set when the current kyoku must not be recorded.
    skip = false
    archive = Archive.load(input_path)
    archive.each_action() do |action|
      archive.dump_action(action) if self.verbose
      case action.type

        when :start_kyoku
          stored_kyoku = StoredKyoku.new([])
          reacher = nil
          skip = false

        when :end_kyoku
          next if skip
          raise("should not happen") if !stored_kyoku
          @stored_kyokus.push(stored_kyoku)
          stored_kyoku = nil

        when :reach_accepted
          # A non-nil reacher means this is the kyoku's second riichi.
          if ["ASAPIN", "(≧▽≦)"].include?(action.actor.name) || reacher
            skip = true
          end
          next if skip
          reacher = action.actor
          waited = TenpaiAnalysis.new(action.actor.tehais).waited_pais
          prereach_sutehais = reacher.sutehais.dup()

        when :dahai
          # Only discards by players not in riichi, after a riichi, are recorded.
          next if skip || !reacher || action.actor.reach?
          scene = Scene.new(archive, action.actor, action.pai, reacher, prereach_sutehais)
          stored_scene = StoredScene.new([])
          #p [:candidates, action.actor, reacher, scene.candidates.join(" ")]
          puts("reacher: %d" % reacher.id) if self.verbose
          for pai in scene.candidates
            # hit: the candidate would have dealt into the riichi.
            hit = waited.include?(pai)
            feature_vector = scene.feature_vector(pai)
            stored_scene.candidates.push([feature_vector, hit])
            if self.verbose
              puts("candidate %s: hit=%d, %s" % [
                  pai,
                  hit ? 1 : 0,
                  DangerEstimator.feature_vector_to_str(feature_vector)])
            end
          end
          stored_kyoku.scenes.push(stored_scene)

      end
    end
  # NOTE(review): rescuing Exception also catches Interrupt/SystemExit; it is
  # only used to annotate the failing path and always re-raises.
  rescue Exception
    $stderr.puts("at #{input_path}")
    raise()
  end
end
|
510
|
+
|
511
|
+
# Computes a DecisionNode for each single-feature criterion. For every
# feature name s, the criteria are {s => false} and {s => true}.
def calculate_single_probabilities(features_path)
  criteria = Scene.feature_names.flat_map { |s| [{ s => false }, { s => true }] }
  return calculate_probabilities(features_path, criteria)
end
|
515
|
+
|
516
|
+
# Greedily grows a decision tree from the features file and returns the node
# for +base_criterion+ (the root on the top-level call).
#
# Each feature not already fixed in +base_criterion+ (a Hash of feature
# name => required value) is tried as a split. The chosen split is the one
# whose two children have the largest gap between their confidence
# intervals, provided that gap exceeds @min_gap. Both children are then
# expanded recursively. A node with no qualifying split stays a leaf.
#
# base_node - node being expanded; computed from base_criterion when nil.
# root      - tree root, kept so the whole tree can be re-rendered after
#             each split.
#
# Every call re-reads features_path. Progress is printed to stdout.
def generate_decision_tree(features_path, base_criterion = {}, base_node = nil, root = nil)
  p [:generate_decision_tree, base_criterion]
  # feature name => [criterion with it false, criterion with it true]
  targets = {}
  criteria = []
  # On the first call the base node itself must also be computed.
  criteria.push(base_criterion) if !base_node
  for name in Scene.feature_names
    next if base_criterion.has_key?(name)
    negative_criterion = base_criterion.merge({name => false})
    positive_criterion = base_criterion.merge({name => true})
    targets[name] = [negative_criterion, positive_criterion]
    criteria.push(negative_criterion, positive_criterion)
  end
  node_map = calculate_probabilities(features_path, criteria)
  base_node = node_map[base_criterion] if !base_node
  root = base_node if !root
  gaps = {}
  for name, (negative_criterion, positive_criterion) in targets
    negative = node_map[negative_criterion]
    positive = node_map[positive_criterion]
    # Criteria that matched no kyoku are absent from node_map.
    next if !positive || !negative
    # Distance between the children's confidence intervals; negative when
    # they overlap.
    if positive.average_prob >= negative.average_prob
      gap = positive.conf_interval[0] - negative.conf_interval[1]
    else
      gap = negative.conf_interval[0] - positive.conf_interval[1]
    end
    p [name, gap]
    gaps[name] = gap if gap > @min_gap
  end
  max_name = gaps.keys.max_by(){ |s| gaps[s] }
  p [:max_name, max_name]
  if max_name
    (negative_criterion, positive_criterion) = targets[max_name]
    base_node.feature_name = max_name
    base_node.negative = node_map[negative_criterion]
    base_node.positive = node_map[positive_criterion]
    # Show the whole tree so far after each split.
    render_decision_tree(root, "all")
    generate_decision_tree(features_path, negative_criterion, base_node.negative, root)
    generate_decision_tree(features_path, positive_criterion, base_node.positive, root)
  end
  return base_node
end
|
557
|
+
|
558
|
+
# Prints +node+ and, recursively, its children to stdout. Each line shows
# the label, the average probability and its confidence interval (both in
# percent), and the sample count. Children are listed in increasing order
# of average probability, one level deeper.
def render_decision_tree(node, label, indent = 0)
  bounds = node.conf_interval
  puts("%s%s : %.2f [%.2f, %.2f] (%d samples)" % [
    " " * indent,
    label,
    node.average_prob * 100.0,
    bounds[0] * 100.0,
    bounds[1] * 100.0,
    node.num_samples,
  ])
  return unless node.feature_name
  children = [[false, node.negative], [true, node.positive]]
  children.sort_by { |_value, child| child.average_prob }.each do |value, child|
    render_decision_tree(child, "%s = %p" % [node.feature_name, value], indent + 1)
  end
end
|
573
|
+
|
574
|
+
# Reads per-kyoku probabilities for +criteria+ from the features file and
# returns them aggregated as a Hash of criterion => DecisionNode.
def calculate_probabilities(features_path, criteria)
  create_kyoku_probs_map(features_path, criteria)
  return aggregate_pribabilities(criteria)
end
|
578
|
+
|
579
|
+
# Streams the features file at +features_path+ and fills @kyoku_probs_map.
# For each criterion in +criteria+ (keyed by criterion.object_id), it
# collects one probability per kyoku that had at least one matching
# candidate.
#
# Each criterion (a Hash of feature name => required boolean) is compiled
# into a [positive_mask, negative_mask] pair for match?:
# - positive_mask has the bits that must be set;
# - negative_mask has the bits that may be set (features required to be
#   false are cleared).
#
# Raises ArgumentError if a criterion names an unknown feature.
# Raises RuntimeError if the file was written with a different feature set.
def create_kyoku_probs_map(features_path, criteria)

  @kyoku_probs_map = {}

  feature_names = Scene.feature_names
  criterion_masks = {}
  for criterion in criteria
    positive_ary = [false] * feature_names.size
    negative_ary = [true] * feature_names.size
    for name, value in criterion
      index = feature_names.index(name)
      # Without this check an unknown name surfaces as an opaque TypeError
      # from indexing the arrays with nil.
      raise(ArgumentError, "unknown feature: %p" % [name]) if !index
      if value
        positive_ary[index] = true
      else
        negative_ary[index] = false
      end
    end
    criterion_masks[criterion] = [
      DangerEstimator.bool_array_to_bit_vector(positive_ary),
      DangerEstimator.bool_array_to_bit_vector(negative_ary),
    ]
  end

  # File.open (not Kernel#open), so a path beginning with "|" is never run
  # as a command. Marshal.load must only be used on trusted files.
  File.open(features_path, "rb") do |f|
    meta_data = Marshal.load(f)
    if meta_data[:feature_names] != feature_names
      raise("feature set has been changed")
    end
    f.with_progress() do
      begin
        # After the metadata, the file is a sequence of Marshal'd
        # StoredKyoku arrays; read until EOF.
        while true
          stored_kyokus = Marshal.load(f)
          for stored_kyoku in stored_kyokus
            update_metrics_for_kyoku(stored_kyoku, criterion_masks)
          end
        end
      rescue EOFError
      end
    end
  end

end
|
620
|
+
|
621
|
+
# Converts @kyoku_probs_map into a Hash of criterion => DecisionNode,
# skipping criteria that matched no kyoku. A node's probability is the mean
# of its per-kyoku probabilities, with a bootstrap confidence interval.
# Each node is also printed to stdout.
def aggregate_pribabilities(criteria)
  criteria.each_with_object({}) do |criterion, nodes|
    kyoku_probs = @kyoku_probs_map[criterion.object_id]
    next unless kyoku_probs
    mean = kyoku_probs.inject(:+) / kyoku_probs.size
    node = DecisionNode.new(mean, confidence_interval(kyoku_probs), kyoku_probs.size)
    nodes[criterion] = node
    low, high = node.conf_interval
    puts("%p\n %.2f [%.2f, %.2f] (%d samples)" %
      [criterion, node.average_prob * 100.0, low * 100.0, high * 100.0, node.num_samples])
  end
end
|
639
|
+
|
640
|
+
# Adds one probability per matching criterion for +stored_kyoku+ to
# @kyoku_probs_map. For each scene, a criterion's scene probability is the
# fraction of its matching candidates that were hits. The kyoku probability
# is the mean of those over the scenes where the criterion matched.
def update_metrics_for_kyoku(stored_kyoku, criterion_masks)
  prob_sums = Hash.new(0.0)
  counts = Hash.new(0)
  stored_kyoku.scenes.each do |stored_scene|
    # criterion object_id => { hit => number of matching candidates }
    freqs_by_criterion = {}
    stored_scene.candidates.each do |feature_vector, hit|
      criterion_masks.each do |criterion, (positive_mask, negative_mask)|
        next unless match?(feature_vector, positive_mask, negative_mask)
        # object_id keys are cheaper to hash than the criterion Hashes.
        tally = (freqs_by_criterion[criterion.object_id] ||= Hash.new(0))
        tally[hit] += 1
      end
    end
    freqs_by_criterion.each do |criterion_id, tally|
      prob_sums[criterion_id] += tally[true].to_f() / (tally[false] + tally[true])
      counts[criterion_id] += 1
    end
  end
  counts.each do |criterion_id, count|
    (@kyoku_probs_map[criterion_id] ||= []).push(prob_sums[criterion_id] / count)
  end
end

# True when +feature_vector+ has every bit of +positive_mask+ set and no bit
# outside +negative_mask+.
def match?(feature_vector, positive_mask, negative_mask)
  required_ok = (feature_vector & positive_mask) == positive_mask
  required_ok && (feature_vector | negative_mask) == negative_mask
end
|
674
|
+
|
675
|
+
# Estimates a confidence interval for the mean of +samples+ (values in
# [0, 1]) by bootstrap resampling.
#
# Two pseudo-samples, 0.0 and 1.0, are added to the pool; presumably this
# keeps the interval from collapsing to a point when there are few samples.
# Each of +num_tries+ rounds draws (samples.size + 2) values with
# replacement from that pool and records their mean.
#
# samples    - Array of Numeric.
# conf_level - coverage of the interval, in [0, 1].
# num_tries  - number of bootstrap rounds (>= 1).
#
# Returns [lower, upper]. The result is random (uses Kernel#rand).
# Raises ArgumentError if num_tries < 1.
def confidence_interval(samples, conf_level = 0.95, num_tries = 1000)
  if num_tries < 1
    raise(ArgumentError, "num_tries must be positive: %p" % [num_tries])
  end
  pool = samples + [0.0, 1.0]
  averages = Array.new(num_tries) do
    sum = 0.0
    pool.size.times() { sum += pool[rand(pool.size)] }
    sum / pool.size
  end
  averages.sort!()
  margin = (1.0 - conf_level) / 2
  # Clamp the indices: with conf_level = 1.0 the upper index would be
  # num_tries, one past the end, which yielded a nil bound.
  last_index = num_tries - 1
  return [
    averages[[(num_tries * margin).to_i(), last_index].min],
    averages[[(num_tries * (1.0 - margin)).to_i(), last_index].min],
  ]
end
|
701
|
+
|
702
|
+
# Packs +bool_array+ into an Integer: element i (when truthy) becomes bit i.
def self.bool_array_to_bit_vector(bool_array)
  vector = 0
  bool_array.each_with_index do |value, bit|
    vector |= 1 << bit if value
  end
  return vector
end
|
710
|
+
|
711
|
+
# Space-separated names of the features whose bits are set in
# +feature_vector+, in registry order.
def self.feature_vector_to_str(feature_vector)
  active = []
  Scene.feature_names.each_with_index do |name, i|
    active << name if feature_vector[i] != 0
  end
  return active.join(" ")
end
|
715
|
+
|
716
|
+
# Returns whether the feature +feature_name+ is set in +feature_vector+.
#
# Raises ArgumentError when +feature_name+ is not a registered feature
# (e.g. a tree built with a different feature set). Previously this failed
# with an opaque TypeError from indexing with nil.
def self.get_feature_value(feature_vector, feature_name)
  index = Scene.feature_names.index(feature_name)
  if !index
    raise(ArgumentError, "unknown feature: %p" % [feature_name])
  end
  return feature_vector[index] != 0
end
|
719
|
+
|
720
|
+
end
|
721
|
+
|
722
|
+
|
723
|
+
end
|
724
|
+
|
725
|
+
end
|
726
|
+
|
727
|
+
|
728
|
+
# Backward-compatibility alias exposing DecisionNode as a top-level constant.
# NOTE(review): presumably kept so that Marshal dumps made when DecisionNode
# was top-level can still be loaded; confirm no such files remain before
# removing.
# TODO Remove this.
DecisionNode = Mjai::Manue::DangerEstimator::DecisionNode
|