mjai-manue 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data/bin/mjai-manue +10 -0
- data/lib/mjai/manue/danger_estimator.rb +730 -0
- data/lib/mjai/manue/hora_points_estimate.rb +488 -0
- data/lib/mjai/manue/hora_probability_estimator.rb +328 -0
- data/lib/mjai/manue/mjai_manue_command.rb +32 -0
- data/lib/mjai/manue/player.rb +373 -0
- data/share/danger.all.tree +0 -0
- data/share/hora_prob.marshal +0 -0
- metadata +65 -0
@@ -0,0 +1,730 @@
|
|
1
|
+
# coding: utf-8
|
2
|
+
|
3
|
+
require "set"
|
4
|
+
require "optparse"
|
5
|
+
|
6
|
+
require "mjai/pai"
|
7
|
+
|
8
|
+
|
9
|
+
module Mjai
|
10
|
+
|
11
|
+
module Manue
|
12
|
+
|
13
|
+
|
14
|
+
class DangerEstimator
|
15
|
+
|
16
|
+
# Snapshot of one discard decision against a player who has declared reach.
# Every feature registered with define_feature is a predicate over a single
# tile ("pai", already stripped of its red flag); feature_vector packs all of
# them into an Integer bit vector. The registration order fixes the bit
# positions, so features must not be reordered.
class Scene

  # Candidate number => discarded numbers (same suit) that make the
  # candidate count as "urasuji" (see urasuji_of).
  URASUJI_INV_MAP = {
    1 => [5],
    2 => [1, 6],
    3 => [2, 7],
    4 => [3, 5, 8],
    5 => [1, 4, 6, 9],
    6 => [2, 5, 7],
    7 => [3, 8],
    8 => [4, 9],
    9 => [5],
  }.each_value(&:freeze).freeze

  # Candidate number => discarded numbers (same suit) that make the
  # candidate count as "senkisuji" (see senkisuji_of). Numbers without an
  # entry never match.
  SENKISUJI_INV_MAP = {
    3 => [1, 8],
    4 => [2, 9],
    5 => [3, 7],
    6 => [1, 8],
    7 => [2, 9],
  }.each_value(&:freeze).freeze

  # Names of all registered features, in registration (= bit) order.
  @@feature_names = []

  # Defines an instance method +name+ from +block+ and appends +name+ to the
  # feature list.
  def self.define_feature(name, &block)
    define_method(name, &block)
    @@feature_names.push(name)
  end

  # Returns the list of feature names in bit order.
  def self.feature_names
    return @@feature_names
  end

  # game:              object providing doras, players and bakaze.
  # me:                the player about to discard (provides tehais).
  # dapai:             tile added to the hand for this decision, or nil.
  # reacher:           the player in reach (provides anpais and jikaze).
  # prereach_sutehais: the reacher's discards up to the reach declaration;
  #                    the last element is treated as the reach tile.
  def initialize(game, me, dapai, reacher, prereach_sutehais)
    @game = game
    @dapai = dapai
    @me = me
    @reacher = reacher
    @prereach_sutehais = prereach_sutehais

    hand = @me.tehais.dup
    hand.push(@dapai) if @dapai
    half = @prereach_sutehais.size / 2

    @anpai_set = to_pai_set(reacher.anpais)
    @prereach_sutehai_set = to_pai_set(@prereach_sutehais)
    @early_sutehai_set = to_pai_set(@prereach_sutehais[0...half])
    @late_sutehai_set = to_pai_set(@prereach_sutehais[half..-1])
    @dora_set = to_pai_set(@game.doras)
    @tehai_set = to_pai_set(hand)

    # Tiles this player can see: dora indicators list, own hand (without
    # dapai), and every player's discards and melds.
    visible = []
    visible.concat(@game.doras)
    visible.concat(@me.tehais)
    @game.players.each do |player|
      visible.concat(player.ho)
      visible.concat(player.furos.map { |furo| furo.pais }.flatten)
    end
    @visible_set = to_pai_set(visible)

    # Distinct hand tiles (red removed) that are not already safe tiles.
    @candidates = hand.
        map { |pai| pai.remove_red }.
        uniq.
        reject { |pai| @anpai_set.has_key?(pai) }
  end

  attr_reader(:candidates)

  # Counts tiles by their red-less form. Missing keys read as 0.
  def to_pai_set(pais)
    pais.each_with_object(Hash.new(0)) do |pai, counts|
      counts[pai.remove_red] += 1
    end
  end

  # pai is without red.
  # Uses a bit vector to make match? faster.
  def feature_vector(pai)
    flags = @@feature_names.map { |name| __send__(name, pai) }
    return DangerEstimator.bool_array_to_bit_vector(flags)
  end

  # True if the tile (ignoring red) is among the reacher's safe tiles.
  def anpai?(pai)
    return @anpai_set.has_key?(pai.remove_red)
  end

  # Tile type is "t".
  define_feature("tsupai") do |pai|
    pai.type == "t"
  end

  # Every suji partner (number +-3 within 1..9) is a safe tile.
  define_feature("suji") do |pai|
    pai.type != "t" &&
      get_suji_numbers(pai).all? { |n| @anpai_set.has_key?(Pai.new(pai.type, n)) }
  end

  # Kata-suji (one-sided suji) or full suji: at least one suji partner is a
  # safe tile.
  define_feature("weak_suji") do |pai|
    suji_of(pai, @anpai_set)
  end

  # All suji partners were discarded before reach, the reach tile is one of
  # them, and the reach tile was discarded exactly once. Never true for
  # type "t" tiles or for 1 and 9.
  define_feature("reach_suji") do |pai|
    declared = reach_pai.remove_red
    if pai.type == "t" || declared.type != pai.type || pai.number == 1 || pai.number == 9
      false
    else
      suji_numbers = get_suji_numbers(pai)
      suji_numbers.all? { |n| @prereach_sutehai_set.include?(Pai.new(pai.type, n)) } &&
        suji_numbers.include?(declared.number) &&
        @prereach_sutehai_set[declared] == 1
    end
  end

  # http://ja.wikipedia.org/wiki/%E7%AD%8B_(%E9%BA%BB%E9%9B%80)#.E8.A3.8F.E3.82.B9.E3.82.B8
  define_feature("urasuji") do |pai|
    urasuji_of(pai, @prereach_sutehai_set)
  end

  # Same as "urasuji" but only against the first half of the discards.
  define_feature("early_urasuji") do |pai|
    urasuji_of(pai, @early_sutehai_set)
  end

  # Same as "urasuji" but only against the reach tile.
  define_feature("reach_urasuji") do |pai|
    urasuji_of(pai, to_pai_set([self.reach_pai]))
  end

  # http://ja.wikipedia.org/wiki/%E7%AD%8B_(%E9%BA%BB%E9%9B%80)#.E9.96.93.E5.9B.9B.E9.96.93
  define_feature("aida4ken") do |pai|
    if pai.type == "t"
      false
    else
      discarded = lambda { |n| @prereach_sutehai_set.has_key?(Pai.new(pai.type, n)) }
      ((2..5).include?(pai.number) &&
          discarded.call(pai.number - 1) &&
          discarded.call(pai.number + 4)) ||
        ((5..8).include?(pai.number) &&
          discarded.call(pai.number - 4) &&
          discarded.call(pai.number + 1))
    end
  end

  # http://ja.wikipedia.org/wiki/%E7%AD%8B_(%E9%BA%BB%E9%9B%80)#.E3.81.BE.E3.81.9F.E3.81.8E.E3.82.B9.E3.82.B8
  define_feature("matagisuji") do |pai|
    matagisuji_of(pai, @prereach_sutehai_set)
  end

  # Same as "matagisuji" but only against the second half of the discards.
  define_feature("late_matagisuji") do |pai|
    matagisuji_of(pai, @late_sutehai_set)
  end

  # http://ja.wikipedia.org/wiki/%E7%AD%8B_(%E9%BA%BB%E9%9B%80)#.E7.96.9D.E6.B0.97.E3.82.B9.E3.82.B8
  define_feature("senkisuji") do |pai|
    senkisuji_of(pai, @prereach_sutehai_set)
  end

  # Same as "senkisuji" but only against the first half of the discards.
  define_feature("early_senkisuji") do |pai|
    senkisuji_of(pai, @early_sutehai_set)
  end

  # Some discard lies between this tile and 5 (inclusive of 5).
  define_feature("outer_prereach_sutehai") do |pai|
    outer(pai, @prereach_sutehai_set)
  end

  # Same as "outer_prereach_sutehai" but only against the early discards.
  define_feature("outer_early_sutehai") do |pai|
    outer(pai, @early_sutehai_set)
  end

  (0..3).each do |n|
    define_feature("chances<=#{n}") do |pai|
      n_chance_or_less(pai, n)
    end
  end

  (1..3).each do |i|
    define_feature("visible>=%d" % i) do |pai|
      visible_n_or_more(pai, i)
    end
  end

  (2..5).each do |i|
    define_feature("%d<=n<=%d" % [i, 10 - i]) do |pai|
      num_n_or_inner(pai, i)
    end
  end

  define_feature("dora") do |pai|
    @dora_set.has_key?(pai)
  end

  define_feature("dora_suji") do |pai|
    suji_of(pai, @dora_set)
  end

  define_feature("dora_matagi") do |pai|
    matagisuji_of(pai, @dora_set)
  end

  # The hand (including dapai) holds at least i copies of the tile.
  (2..4).each do |i|
    define_feature("in_tehais>=#{i}") do |pai|
      @tehai_set[pai] >= i
    end
  end

  # The hand holds at least i copies of some suji partner of the tile.
  (2..4).each do |i|
    define_feature("suji_in_tehais>=#{i}") do |pai|
      pai.type != "t" &&
        get_suji_numbers(pai).any? { |n| @tehai_set[Pai.new(pai.type, n)] >= i }
    end
  end

  (1..2).each do |i|
    (1..(i * 2)).each do |j|
      define_feature("+-#{i}_in_prereach_sutehais>=#{j}") do |pai|
        n_or_more_of_neighbors_in_prereach_sutehais(pai, j, i)
      end
    end
  end

  (1..2).each do |i|
    define_feature("#{i}_outer_prereach_sutehai") do |pai|
      n_outer_prereach_sutehai(pai, i)
    end
  end

  # Implemented by passing a negative distance to n_outer_prereach_sutehai.
  (1..2).each do |i|
    define_feature("#{i}_inner_prereach_sutehai") do |pai|
      n_outer_prereach_sutehai(pai, -i)
    end
  end

  # At least i distinct numbers of the tile's suit appear in the discards.
  (1..8).each do |i|
    define_feature("same_type_in_prereach>=#{i}") do |pai|
      if pai.type == "t"
        false
      else
        kinds = (1..9).count do |n|
          @prereach_sutehai_set.has_key?(Pai.new(pai.type, n))
        end
        kinds >= i
      end
    end
  end

  define_feature("fanpai") do |pai|
    fanpai_fansu(pai) >= 1
  end

  define_feature("ryenfonpai") do |pai|
    fanpai_fansu(pai) >= 2
  end

  define_feature("sangenpai") do |pai|
    pai.type == "t" && pai.number >= 5
  end

  define_feature("fonpai") do |pai|
    pai.type == "t" && pai.number < 5
  end

  define_feature("bakaze") do |pai|
    pai == @game.bakaze
  end

  define_feature("jikaze") do |pai|
    pai == @reacher.jikaze
  end

  # True if the tile n steps toward 5 (away from 5 when n is negative) was
  # discarded before reach. Only evaluated when
  # pai.number < 6 - n || pai.number > 4 + n.
  def n_outer_prereach_sutehai(pai, n)
    return false if pai.type == "t"
    return false unless pai.number < 6 - n || pai.number > 4 + n
    inner_number = pai.number < 5 ? pai.number + n : pai.number - n
    return @prereach_sutehai_set.include?(Pai.new(pai.type, inner_number))
  end

  # True if at least n distinct numbers in
  # (pai.number - neighbor_distance)..(pai.number + neighbor_distance)
  # (the tile's own number included) were discarded before reach.
  def n_or_more_of_neighbors_in_prereach_sutehais(pai, n, neighbor_distance)
    return false if pai.type == "t"
    window = (pai.number - neighbor_distance)..(pai.number + neighbor_distance)
    # The block variable is deliberately not called "n": it must not shadow
    # the threshold parameter.
    num_neighbors = window.count do |number|
      @prereach_sutehai_set.has_key?(Pai.new(pai.type, number))
    end
    return num_neighbors >= n
  end

  # True if any suji partner of the tile is in target_pai_set.
  def suji_of(pai, target_pai_set)
    return false if pai.type == "t"
    return get_suji_numbers(pai).any? do |n|
      target_pai_set.has_key?(Pai.new(pai.type, n))
    end
  end

  # Numbers 3 below and 3 above the tile that fall within 1..9.
  def get_suji_numbers(pai)
    return [pai.number - 3, pai.number + 3].select { |n| (1..9).include?(n) }
  end

  # True if one of the two tiles next toward the middle has at least 4 - n
  # visible copies. Always false for type "t" tiles and for 4..6.
  def n_chance_or_less(pai, n)
    return false if pai.type == "t" || (4..6).include?(pai.number)
    step = pai.number < 5 ? 1 : -1
    return (1..2).any? do |i|
      kabe_pai = Pai.new(pai.type, pai.number + step * i)
      @visible_set[kabe_pai] >= 4 - n
    end
  end

  # True for suited tiles whose number lies in n..(10 - n).
  def num_n_or_inner(pai, n)
    return pai.type != "t" && pai.number >= n && pai.number <= 10 - n
  end

  def visible_n_or_more(pai, n)
    # n doesn't include itself.
    return @visible_set[pai] >= n + 1
  end

  # True if any number listed for the tile in URASUJI_INV_MAP is in
  # target_pai_set.
  def urasuji_of(pai, target_pai_set)
    return false if pai.type == "t"
    urasuji_numbers = URASUJI_INV_MAP[pai.number]
    return urasuji_numbers.any? do |n|
      target_pai_set.has_key?(Pai.new(pai.type, n))
    end
  end

  # Like urasuji_of but with SENKISUJI_INV_MAP; returns nil (falsy) when the
  # table has no entry for the tile's number.
  def senkisuji_of(pai, target_pai_set)
    return false if pai.type == "t"
    senkisuji_numbers = SENKISUJI_INV_MAP[pai.number]
    return senkisuji_numbers &&
      senkisuji_numbers.any? { |n| target_pai_set.has_key?(Pai.new(pai.type, n)) }
  end

  # True if target_pai_set holds one of the two numbers just below the tile
  # (only checked for 4 and up) or just above it (only checked for 6 and
  # below).
  def matagisuji_of(pai, target_pai_set)
    return false if pai.type == "t"
    matagisuji_numbers = []
    matagisuji_numbers += [pai.number - 2, pai.number - 1] if pai.number >= 4
    matagisuji_numbers += [pai.number + 1, pai.number + 2] if pai.number <= 6
    return matagisuji_numbers.any? do |n|
      target_pai_set.has_key?(Pai.new(pai.type, n))
    end
  end

  # True if target_pai_set holds a tile strictly between this tile and 5,
  # or 5 itself. Always false for type "t" tiles and for 5.
  def outer(pai, target_pai_set)
    return false if pai.type == "t" || pai.number == 5
    inner_numbers = pai.number < 5 ? ((pai.number + 1)..5) : (5..(pai.number - 1))
    return inner_numbers.any? do |n|
      target_pai_set.has_key?(Pai.new(pai.type, n))
    end
  end

  # The last pre-reach discard (red flag not removed).
  def reach_pai
    return @prereach_sutehais[-1]
  end

  # 1 for type "t" tiles numbered 5 or higher; otherwise one point each for
  # equalling the round wind (game.bakaze) and the reacher's seat wind.
  def fanpai_fansu(pai)
    if pai.type == "t" && pai.number >= 5
      return 1
    else
      return (pai == @game.bakaze ? 1 : 0) + (pai == @reacher.jikaze ? 1 : 0)
    end
  end

end
|
398
|
+
|
399
|
+
# One round (kyoku): the list of scenes recorded during it.
StoredKyoku = Struct.new(:scenes)

# One discard decision: a list of [feature_vector, hit] pairs, one per
# candidate tile.
StoredScene = Struct.new(:candidates)

# A node of the decision tree. feature_name is nil on leaves; on inner nodes
# positive/negative are the children taken when the feature is set/unset.
DecisionNode = Struct.new(
  :average_prob,
  :conf_interval,
  :num_samples,
  :feature_name,
  :positive,
  :negative
)
|
403
|
+
|
404
|
+
# Binary decision tree over Scene features, loaded from a marshalled file.
class DecisionTree

  # path: file holding a Marshal dump of the root node.
  #
  # The file is opened with File.open rather than Kernel#open so that a path
  # beginning with "|" is never executed as a command.
  # NOTE(review): Marshal.load can instantiate arbitrary objects; only load
  # tree files that come from a trusted source.
  def initialize(path)
    @root = File.open(path, "rb") { |f| Marshal.load(f) }
  end

  # Walks the tree for +pai+ (red flag ignored) in +scene+ and returns the
  # average_prob stored on the leaf that is reached.
  # Also prints the tile and its active feature names via Kernel#p.
  def estimate_prob(scene, pai)
    feature_vector = scene.feature_vector(pai.remove_red())
    p [pai, DangerEstimator.feature_vector_to_str(feature_vector)]
    node = @root
    # Inner nodes carry a feature_name; leaves do not.
    while node.feature_name
      if DangerEstimator.get_feature_value(feature_vector, node.feature_name)
        node = node.positive
      else
        node = node.negative
      end
    end
    return node.average_prob
  end

end
|
425
|
+
|
426
|
+
# Starts with a minimum gap of 0.0; verbose is left unset (nil).
def initialize
  @min_gap = 0.0
end

# verbose: when truthy, extract_features_from_file dumps every action and
#          every candidate it records.
# min_gap: generate_decision_tree only splits on a feature whose confidence
#          interval gap is strictly greater than this value.
attr_accessor(:verbose, :min_gap)
|
432
|
+
|
433
|
+
# Extracts features from every file in +input_paths+ and writes them to
# +output_path+ as a sequence of Marshal dumps: first a meta data hash
# ({:feature_names => ...}), then arrays of stored kyokus. The buffer is
# flushed before processing files 100, 200, ... and once more at the end.
#
# The output is opened with File.open rather than Kernel#open so that a path
# beginning with "|" is never executed as a command.
def extract_features_from_files(input_paths, output_path)
  $stderr.puts("%d files." % input_paths.size)
  File.open(output_path, "wb") do |f|
    meta_data = {
      :feature_names => Scene.feature_names,
    }
    Marshal.dump(meta_data, f)
    # extract_features_from_file appends to this buffer.
    @stored_kyokus = []
    input_paths.enum_for(:each_with_progress).each_with_index() do |path, i|
      if i % 100 == 0 && i > 0
        Marshal.dump(@stored_kyokus, f)
        @stored_kyokus.clear()
      end
      extract_features_from_file(path)
    end
    Marshal.dump(@stored_kyokus, f)
  end
end
|
451
|
+
|
452
|
+
# Replays the archive at +input_path+ and appends one StoredKyoku per
# completed round to @stored_kyokus. After a reach is accepted, every discard
# by a player not in reach yields a StoredScene whose candidates are
# [feature_vector, hit] pairs; hit tells whether the tile is one the reacher
# was waiting on.
#
# A round is skipped entirely once a second reach is accepted, or when the
# reacher's name is one of the two hard-coded names below.
def extract_features_from_file(input_path)
  stored_kyoku = nil
  reacher = nil
  waited = nil
  prereach_sutehais = nil
  skip = false
  archive = Archive.load(input_path)
  archive.each_action do |action|
    archive.dump_action(action) if self.verbose
    case action.type

      when :start_kyoku
        stored_kyoku = StoredKyoku.new([])
        reacher = nil
        skip = false

      when :end_kyoku
        next if skip
        raise("should not happen") if !stored_kyoku
        @stored_kyokus.push(stored_kyoku)
        stored_kyoku = nil

      when :reach_accepted
        skip = true if ["ASAPIN", "(≧▽≦)"].include?(action.actor.name) || reacher
        next if skip
        reacher = action.actor
        waited = TenpaiAnalysis.new(action.actor.tehais).waited_pais
        prereach_sutehais = reacher.sutehais.dup

      when :dahai
        next if skip || !reacher || action.actor.reach?
        scene = Scene.new(archive, action.actor, action.pai, reacher, prereach_sutehais)
        puts("reacher: %d" % reacher.id) if self.verbose
        candidates = scene.candidates.map do |pai|
          hit = waited.include?(pai)
          feature_vector = scene.feature_vector(pai)
          if self.verbose
            puts("candidate %s: hit=%d, %s" % [
                pai,
                hit ? 1 : 0,
                DangerEstimator.feature_vector_to_str(feature_vector)])
          end
          [feature_vector, hit]
        end
        stored_kyoku.scenes.push(StoredScene.new(candidates))

    end
  end
rescue Exception
  # Only adds the file name as context; the exception itself is re-raised.
  $stderr.puts("at #{input_path}")
  raise
end
|
510
|
+
|
511
|
+
# Calculates probabilities for every single-feature criterion: for each
# feature name, first {name => false} and then {name => true}.
def calculate_single_probabilities(features_path)
  criteria = []
  Scene.feature_names.each do |name|
    criteria.push({name => false})
    criteria.push({name => true})
  end
  calculate_probabilities(features_path, criteria)
end
|
515
|
+
|
516
|
+
# Recursively grows a decision tree from the features stored at
# +features_path+ and returns the node for +base_criterion+.
#
# For every feature not yet fixed in +base_criterion+ the method compares the
# nodes for "feature unset" and "feature set". The gap is the distance
# between the confidence intervals (lower bound of the higher-probability
# node minus upper bound of the other). The feature with the largest gap
# above @min_gap becomes the split; both children are then expanded.
#
# base_node and root are nil on the outermost call and are filled in from
# the computed node map. Progress is printed with Kernel#p, and the whole
# tree is rendered after every split.
def generate_decision_tree(features_path, base_criterion = {}, base_node = nil, root = nil)
  p [:generate_decision_tree, base_criterion]

  targets = {}
  criteria = []
  criteria.push(base_criterion) if !base_node
  Scene.feature_names.each do |name|
    next if base_criterion.has_key?(name)
    # Pair is ordered [negative, positive].
    pair = [false, true].map { |value| base_criterion.merge({name => value}) }
    targets[name] = pair
    criteria.concat(pair)
  end

  node_map = calculate_probabilities(features_path, criteria)
  base_node ||= node_map[base_criterion]
  root ||= base_node

  gaps = {}
  targets.each do |name, (negative_criterion, positive_criterion)|
    negative = node_map[negative_criterion]
    positive = node_map[positive_criterion]
    next if !positive || !negative
    if positive.average_prob >= negative.average_prob
      higher, lower = positive, negative
    else
      higher, lower = negative, positive
    end
    gap = higher.conf_interval[0] - lower.conf_interval[1]
    p [name, gap]
    gaps[name] = gap if gap > @min_gap
  end

  max_name = gaps.keys.max_by { |s| gaps[s] }
  p [:max_name, max_name]
  if max_name
    (negative_criterion, positive_criterion) = targets[max_name]
    base_node.feature_name = max_name
    base_node.negative = node_map[negative_criterion]
    base_node.positive = node_map[positive_criterion]
    render_decision_tree(root, "all")
    generate_decision_tree(features_path, negative_criterion, base_node.negative, root)
    generate_decision_tree(features_path, positive_criterion, base_node.positive, root)
  end
  return base_node
end
|
557
|
+
|
558
|
+
# Prints +node+ as one line (label, average probability and confidence
# interval in percent, sample count), indented by +indent+ copies of the
# indent string, then prints its children one level deeper, lower average
# probability first.
def render_decision_tree(node, label, indent = 0)
  line = "%s%s : %.2f [%.2f, %.2f] (%d samples)" % [
    " " * indent,
    label,
    node.average_prob * 100.0,
    node.conf_interval[0] * 100.0,
    node.conf_interval[1] * 100.0,
    node.num_samples,
  ]
  puts(line)
  return if !node.feature_name

  children = [[false, node.negative], [true, node.positive]]
  children.sort_by { |_value, child| child.average_prob }.each do |value, child|
    render_decision_tree(child, "%s = %p" % [node.feature_name, value], indent + 1)
  end
end
|
573
|
+
|
574
|
+
# Scans the feature file once to collect per-kyoku probabilities for every
# criterion, then aggregates them. Returns the hash produced by
# aggregate_pribabilities (criterion => DecisionNode).
def calculate_probabilities(features_path, criteria)
  create_kyoku_probs_map(features_path, criteria)
  return aggregate_pribabilities(criteria)
end
|
578
|
+
|
579
|
+
# Resets @kyoku_probs_map and refills it by streaming every stored kyoku in
# +features_path+ through update_metrics_for_kyoku.
#
# Each criterion (a hash of feature name => required boolean) is compiled
# into two bit masks: a positive mask with a 1 for every feature that must be
# set, and a negative mask with a 0 for every feature that must be unset.
#
# Raises RuntimeError if the feature names recorded in the file differ from
# Scene.feature_names.
#
# The file is opened with File.open rather than Kernel#open so that a path
# beginning with "|" is never executed as a command.
# NOTE(review): Marshal.load can instantiate arbitrary objects; only read
# feature files that come from a trusted source.
def create_kyoku_probs_map(features_path, criteria)

  @kyoku_probs_map = {}

  feature_names = Scene.feature_names
  criterion_masks = {}
  criteria.each do |criterion|
    positive_ary = [false] * feature_names.size
    negative_ary = [true] * feature_names.size
    criterion.each do |name, value|
      index = feature_names.index(name)
      if value
        positive_ary[index] = true
      else
        negative_ary[index] = false
      end
    end
    criterion_masks[criterion] = [
      DangerEstimator.bool_array_to_bit_vector(positive_ary),
      DangerEstimator.bool_array_to_bit_vector(negative_ary),
    ]
  end

  File.open(features_path, "rb") do |f|
    meta_data = Marshal.load(f)
    if meta_data[:feature_names] != Scene.feature_names
      raise("feature set has been changed")
    end
    f.with_progress() do
      begin
        # The file is a sequence of dumps; reading past the last one raises
        # EOFError, which ends the loop.
        while true
          stored_kyokus = Marshal.load(f)
          stored_kyokus.each do |stored_kyoku|
            update_metrics_for_kyoku(stored_kyoku, criterion_masks)
          end
        end
      rescue EOFError
      end
    end
  end

end
|
620
|
+
|
621
|
+
# Builds a DecisionNode for every criterion that has entries in
# @kyoku_probs_map (keyed by the criterion's object_id): the mean of its
# per-kyoku probabilities, their confidence interval and the sample count.
# Prints a summary line per criterion. Returns criterion => DecisionNode;
# criteria without samples are omitted.
def aggregate_pribabilities(criteria)
  result = {}
  criteria.each do |criterion|
    kyoku_probs = @kyoku_probs_map[criterion.object_id]
    next if !kyoku_probs

    num_samples = kyoku_probs.size
    node = DecisionNode.new(
        kyoku_probs.inject(:+) / num_samples,
        confidence_interval(kyoku_probs),
        num_samples)
    result[criterion] = node

    summary = "%p\n %.2f [%.2f, %.2f] (%d samples)" % [
      criterion,
      node.average_prob * 100.0,
      node.conf_interval[0] * 100.0,
      node.conf_interval[1] * 100.0,
      node.num_samples,
    ]
    puts(summary)
  end
  return result
end
|
639
|
+
|
640
|
+
# Adds one probability per matching criterion to @kyoku_probs_map for the
# given kyoku.
#
# Within a scene, the probability for a criterion is hits / matches over the
# candidates whose feature vector matches the criterion's masks. The kyoku
# probability is the mean of those scene probabilities over the scenes that
# had at least one match. Keys of @kyoku_probs_map are criterion object_ids.
def update_metrics_for_kyoku(stored_kyoku, criterion_masks)
  prob_sums = Hash.new(0.0)
  scene_totals = Hash.new(0)

  stored_kyoku.scenes.each do |stored_scene|
    tallies = {}
    stored_scene.candidates.each do |feature_vector, hit|
      criterion_masks.each do |criterion, (positive_mask, negative_mask)|
        next if !match?(feature_vector, positive_mask, negative_mask)
        # Uses object_id as key for efficiency.
        key = criterion.object_id
        tallies[key] ||= Hash.new(0)
        tallies[key][hit] += 1
      end
    end
    tallies.each do |criterion_id, freqs|
      hits = freqs[true]
      prob_sums[criterion_id] += hits.to_f / (freqs[false] + hits)
      scene_totals[criterion_id] += 1
    end
  end

  scene_totals.each do |criterion_id, count|
    (@kyoku_probs_map[criterion_id] ||= []).push(prob_sums[criterion_id] / count)
  end
end
|
669
|
+
|
670
|
+
# True when +feature_vector+ has every bit of +positive_mask+ set and no bit
# set outside +negative_mask+.
def match?(feature_vector, positive_mask, negative_mask)
  has_all_required = (feature_vector & positive_mask) == positive_mask
  has_forbidden = (feature_vector | negative_mask) != negative_mask
  return has_all_required && !has_forbidden
end
|
674
|
+
|
675
|
+
# Estimates a confidence interval for the mean of +samples+ using bootstrap
# resampling (1000 resamples).
#
# Each resample draws samples.size + 2 values from the samples extended with
# two artificial values, 0.0 and 1.0. Returns [lower, upper], read from the
# sorted resample means at the (1 - conf_level) / 2 quantiles.
#
# Quantile indices are clamped to the valid range, so conf_level = 1.0
# yields the smallest and largest resample mean instead of a nil upper
# bound.
def confidence_interval(samples, conf_level = 0.95)
  num_tries = 1000
  pool_size = samples.size + 2
  averages = []
  num_tries.times() do
    sum = 0.0
    pool_size.times() do
      idx = rand(pool_size)
      case idx
        when samples.size
          sum += 0.0
        when samples.size + 1
          sum += 1.0
        else
          sum += samples[idx]
      end
    end
    averages.push(sum / pool_size)
  end
  averages.sort!()
  margin = (1.0 - conf_level) / 2
  last_index = num_tries - 1
  lower_index = [[(num_tries * margin).to_i(), 0].max, last_index].min
  upper_index = [[(num_tries * (1.0 - margin)).to_i(), 0].max, last_index].min
  return [
    averages[lower_index],
    averages[upper_index],
  ]
end
|
701
|
+
|
702
|
+
# Packs +bool_array+ into an Integer: element i sets bit i (element 0 is the
# least significant bit) when it is truthy.
def self.bool_array_to_bit_vector(bool_array)
  bits = 0
  bool_array.each_with_index do |flag, position|
    bits |= (1 << position) if flag
  end
  return bits
end
|
710
|
+
|
711
|
+
# Returns the names of the features whose bit is set in +feature_vector+,
# in feature order, joined by single spaces.
def self.feature_vector_to_str(feature_vector)
  active = []
  Scene.feature_names.each_with_index do |name, i|
    active.push(name) if feature_vector[i] != 0
  end
  return active.join(" ")
end
|
715
|
+
|
716
|
+
# Returns true if the bit for +feature_name+ is set in +feature_vector+.
#
# Raises ArgumentError when +feature_name+ is not one of
# Scene.feature_names (for example when a stored tree refers to a feature
# that no longer exists), instead of failing inside Integer#[] with a
# TypeError about nil.
def self.get_feature_value(feature_vector, feature_name)
  index = Scene.feature_names.index(feature_name)
  if !index
    raise(ArgumentError, "unknown feature: %p" % [feature_name])
  end
  return feature_vector[index] != 0
end
|
719
|
+
|
720
|
+
end
|
721
|
+
|
722
|
+
|
723
|
+
end
|
724
|
+
|
725
|
+
end
|
726
|
+
|
727
|
+
|
728
|
+
# For compatibility.
|
729
|
+
# TODO Remove this.
|
730
|
+
# Top-level alias for the namespaced DecisionNode struct.
# NOTE(review): presumably kept so that tree files marshalled while the
# struct was a top-level constant can still be loaded - confirm before
# removing.
DecisionNode = Mjai::Manue::DangerEstimator::DecisionNode
|