winston 0.0.2 → 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,529 @@
1
+ require "set"
2
+
3
module Winston
  module Solvers
    # Backtracking solver that Maintains Arc Consistency (MAC).
    #
    # Domains are pruned before the search starts and after every tentative
    # assignment: AC-3 runs over the binary constraints and a matching-based
    # filter (Régin) runs over every Winston::Constraints::AllDifferent
    # constraint. Each pruning step is pushed onto a trail so backtracking can
    # undo it exactly.
    class MAC
      # Frozen fallback for arcs that carry no constraints of a given kind.
      NO_CONSTRAINTS = [].freeze
      private_constant :NO_CONSTRAINTS

      # @param csp problem object; this class calls #variables (a Hash of
      #   name => variable), #constraints, #domain_for(name) and, for the :lcv
      #   value strategy, #validate(name, assignments)
      # @param variable_strategy [Symbol, #call] :mrv (default), :first, or a
      #   callable invoked as call(unassigned_variables, assignments, csp)
      # @param value_strategy [Symbol, #call] :in_order (default), :lcv, or a
      #   callable invoked as call(values, variable, assignments, csp)
      def initialize(csp, variable_strategy: :mrv, value_strategy: :in_order)
        @csp = csp
        @variable_strategy = variable_strategy
        @value_strategy = value_strategy
        @neighbors = build_neighbors
        @binary_constraints = build_binary_constraints
        @binary_predicates = build_binary_predicates
        @unary_constraints = build_unary_constraints
        @constraints_for_var = build_constraints_for_var
        @all_different_constraints = build_all_different_constraints
        @all_different_propagator = AllDifferentPropagator.new
      end

      # Looks for a complete assignment that satisfies every constraint.
      #
      # @param assignments [Hash] partial assignment (name => value) that the
      #   solution must extend; it is temporarily mutated while constraints are
      #   evaluated and restored afterwards
      # @return [Hash, false] a complete assignment, or false if none exists
      def search(assignments = {})
        domains = initial_domains(assignments)
        return false if domains.nil?

        trail = []
        return false unless propagate(domains, assignments, all_arcs, trail)

        search_with_domains(assignments, domains, trail)
      end

      private

      attr_reader :csp,
                  :variable_strategy,
                  :value_strategy,
                  :neighbors,
                  :binary_constraints,
                  :binary_predicates,
                  :unary_constraints,
                  :constraints_for_var,
                  :all_different_constraints,
                  :all_different_propagator

      # Depth-first search over already-propagated domains.
      #
      # Every domain change made while trying a value is undone through the
      # trail before the next value is tried. This holds whether the failure
      # came from propagation or from the recursive search.
      def search_with_domains(assignments, domains, trail)
        return assignments if complete?(assignments)

        var = select_unassigned_variable(assignments, domains)
        values_for(var, domains, assignments).each do |value|
          next unless consistent_with_constraints(var.name, value, assignments)

          assigned = assignments.merge(var.name => value)
          mark = trail.length
          assign_domain(var.name, value, domains, trail)

          if propagate(domains, assigned, arcs_from(var.name), trail)
            result = search_with_domains(assigned, domains, trail)
            return result if result
          end

          # Propagation may have failed half-way, possibly leaving a domain
          # empty. Roll all of it back so it cannot leak into later values.
          restore(domains, trail, mark)
        end

        false
      end

      # Runs AllDifferent filtering, then AC-3 over +arcs+, then AllDifferent
      # filtering again.
      #
      # Returns false as soon as a step wipes out a domain. Pruning done before
      # that point stays on the trail for the caller to undo.
      def propagate(domains, assignments, arcs, trail)
        propagate_all_different(domains, trail) &&
          ac3(domains, assignments, arcs, trail) &&
          propagate_all_different(domains, trail)
      end

      # True once every CSP variable has a value.
      def complete?(assignments)
        assignments.size == csp.variables.size
      end

      # Builds the starting domain of each variable, then applies the unary
      # constraints.
      #
      # A variable's domain is, in order of preference: the given assignment,
      # the variable's preset value, or a copy of its declared domain.
      #
      # @return [Hash, nil] name => candidate values, or nil if a declared
      #   domain is empty or unary filtering empties one
      def initial_domains(assignments)
        domains = {}
        csp.variables.each do |name, variable|
          if assignments.key?(name)
            domains[name] = [assignments[name]]
          elsif !variable.value.nil?
            domains[name] = [variable.value]
          else
            values = csp.domain_for(name)
            return nil if values.empty?

            domains[name] = values.dup
          end
        end

        prune_unary(domains, assignments) ? domains : nil
      end

      # Keeps only the values that satisfy every unary constraint of a variable.
      #
      # @return [Boolean] false if some variable is left without values
      def prune_unary(domains, assignments)
        unary_constraints.each do |name, constraints|
          next unless domains.key?(name)

          allowed = domains[name].select do |value|
            with_assignment(assignments, name, value) do
              constraints.all? { |constraint| constraint.validate(assignments) }
            end
          end
          return false if allowed.empty?

          domains[name] = allowed
        end
        true
      end

      # Picks the next variable to branch on according to variable_strategy.
      # Unknown strategies fall back to the first unassigned variable.
      def select_unassigned_variable(assignments, domains)
        vars = unassigned_variables(assignments)
        return vars.first if variable_strategy == :first
        return variable_strategy.call(vars, assignments, csp) if variable_strategy.respond_to?(:call)
        return vars.min_by { |var| domains[var.name].size } if variable_strategy == :mrv

        vars.first
      end

      # Orders +var+'s candidate values according to value_strategy.
      # Unknown strategies keep the domain order.
      def values_for(var, domains, assignments)
        values = domains[var.name]
        return values if value_strategy == :in_order
        return value_strategy.call(values, var, assignments, csp) if value_strategy.respond_to?(:call)
        return order_lcv(values, var, assignments, domains) if value_strategy == :lcv

        values
      end

      # Least-constraining-value ordering.
      #
      # Each value is scored by how many (other variable, value) pairs remain
      # valid per csp.validate. Higher scores come first.
      def order_lcv(values, var, assignments, domains)
        other_vars = unassigned_variables(assignments).reject { |v| v.name == var.name }
        scored = values.map do |value|
          score = other_vars.sum do |other|
            domains[other.name].count do |other_value|
              with_assignment(assignments, var.name, value) do
                with_assignment(assignments, other.name, other_value) do
                  csp.validate(other.name, assignments)
                end
              end
            end
          end
          [value, score]
        end
        scored.sort_by { |(_, score)| -score }.map(&:first)
      end

      # Variable objects (in csp.variables order) without a value in +assignments+.
      def unassigned_variables(assignments)
        csp.variables.reject { |k, _| assignments.include?(k) }.each_value.to_a
      end

      # AC-3: revises arcs until no domain changes.
      #
      # +queue+ defaults to every arc. It is copied, so the caller's array is
      # never modified.
      #
      # @return [Boolean] false if some domain becomes empty
      def ac3(domains, assignments, queue = nil, trail = nil)
        queue = (queue || all_arcs).dup
        idx = 0
        while idx < queue.length
          xi, xj = queue[idx]
          idx += 1
          next unless revise(domains, assignments, xi, xj, trail)
          return false if domains[xi].empty?

          # xi shrank, so arcs pointing at it must be rechecked.
          neighbors[xi].each { |xk| queue << [xk, xi] unless xk == xj }
        end
        true
      end

      # Removes every value of xi that has no supporting value in xj's domain
      # under the constraints linking xi and xj.
      #
      # @return [Boolean] whether xi's domain changed
      def revise(domains, assignments, xi, xj, trail)
        # fetch (not []) so lookups never insert empty entries through the
        # default procs of these hashes.
        constraints = binary_constraints.fetch([xi, xj], NO_CONSTRAINTS)
        predicates = binary_predicates.fetch([xi, xj], NO_CONSTRAINTS)
        return false if constraints.empty? && predicates.empty?

        revised = false
        domains[xi].dup.each do |x|
          supported = domains[xj].any? do |y|
            # Cheap direct predicates first; full constraint checks only if they pass.
            next false unless predicates_satisfied?(predicates, x, y)

            with_assignment(assignments, xi, x) do
              with_assignment(assignments, xj, y) do
                constraints.all? { |constraint| constraint.validate(assignments) }
              end
            end
          end
          next if supported

          remove_value(domains, xi, x, trail)
          revised = true
        end
        revised
      end

      # Every directed arc [xi, xj] between constrained variable pairs.
      def all_arcs
        neighbors.flat_map { |xi, list| list.map { |xj| [xi, xj] } }
      end

      # Arcs pointing at +var_name+, i.e. the ones to recheck after it changes.
      def arcs_from(var_name)
        (neighbors[var_name] || []).map { |neighbor| [neighbor, var_name] }
      end

      # name => names of variables sharing a two-variable constraint with it.
      def build_neighbors
        map = Hash.new { |h, k| h[k] = Set.new }
        csp.constraints.each do |constraint|
          vars = constraint.variables
          next unless vars.size == 2

          a, b = vars
          map[a] << b
          map[b] << a
        end
        map.transform_values(&:to_a)
      end

      # [xi, xj] => two-variable constraints that must be checked through
      # #validate. Registered under both arc directions.
      def build_binary_constraints
        map = Hash.new { |h, k| h[k] = [] }
        csp.constraints.each do |constraint|
          vars = constraint.variables
          next unless vars.size == 2
          next if direct_predicate?(constraint)

          a, b = vars
          map[[a, b]] << constraint
          map[[b, a]] << constraint
        end
        map
      end

      # [xi, xj] => callables taking (xi_value, xj_value).
      # The reverse arc gets a wrapper that swaps the arguments.
      def build_binary_predicates
        map = Hash.new { |h, k| h[k] = [] }
        csp.constraints.each do |constraint|
          vars = constraint.variables
          next unless vars.size == 2
          next unless direct_predicate?(constraint)

          a, b = vars
          predicate = constraint.predicate
          map[[a, b]] << predicate
          map[[b, a]] << ->(x, y) { predicate.call(y, x) }
        end
        map
      end

      # name => constraints whose only variable is that name.
      def build_unary_constraints
        map = Hash.new { |h, k| h[k] = [] }
        csp.constraints.each do |constraint|
          vars = constraint.variables
          next unless vars.size == 1

          map[vars.first] << constraint
        end
        map
      end

      def build_all_different_constraints
        csp.constraints.select { |constraint| constraint.is_a?(Winston::Constraints::AllDifferent) }
      end

      # name => every constraint touching that variable.
      # Global constraints are attached to all variables.
      def build_constraints_for_var
        map = Hash.new { |h, k| h[k] = [] }
        csp.variables.each_key { |name| map[name] = [] }

        csp.constraints.each do |constraint|
          if constraint.global
            csp.variables.each_key { |name| map[name] << constraint }
          else
            constraint.variables.each { |name| map[name] << constraint }
          end
        end

        map
      end

      # True for Winston::Constraint objects whose predicate has arity <= 2.
      # Such predicates are called directly with the two values.
      def direct_predicate?(constraint)
        constraint.is_a?(Winston::Constraint) &&
          constraint.predicate &&
          constraint.predicate.arity <= 2
      end

      def predicates_satisfied?(predicates, x, y)
        predicates.all? { |predicate| predicate.call(x, y) }
      end

      # Checks the constraints of +var_name+ with +value+ tentatively assigned.
      #
      # Global constraints are only checked once every variable has a value.
      # Other constraints are checked once they report being eligible.
      def consistent_with_constraints(var_name, value, assignments)
        with_assignment(assignments, var_name, value) do
          constraints_for_var[var_name].all? do |constraint|
            if constraint.global
              assignments.size == csp.variables.size ? constraint.validate(assignments) : true
            else
              !constraint.elegible_for?(var_name, assignments) || constraint.validate(assignments)
            end
          end
        end
      end

      # Narrows +var_name+'s domain to [value].
      # The previous array is recorded on the trail.
      def assign_domain(var_name, value, domains, trail)
        return if domains[var_name].size == 1 && domains[var_name].first == value

        trail << [:domain, var_name, domains[var_name]]
        domains[var_name] = [value]
      end

      # Applies GAC for every AllDifferent constraint that lists variables.
      # @return [Boolean] false as soon as one of them is unsatisfiable
      def propagate_all_different(domains, trail)
        all_different_constraints.each do |constraint|
          vars = constraint.variables
          next if vars.empty?
          return false unless all_different_propagator.gac(domains, trail, vars)
        end
        true
      end

      # Temporarily sets assignments[var_name] = value while yielding.
      # The previous state (present value or absent key) is restored even on exceptions.
      def with_assignment(assignments, var_name, value)
        had = assignments.key?(var_name)
        previous = assignments[var_name]
        assignments[var_name] = value
        yield
      ensure
        had ? assignments[var_name] = previous : assignments.delete(var_name)
      end

      # Deletes +value+ from a domain in place and records its index for undo.
      def remove_value(domains, var_name, value, trail)
        idx = domains[var_name].index(value)
        return if idx.nil?

        trail << [:value, var_name, value, idx]
        domains[var_name].delete_at(idx)
      end

      # Pops trail entries until the trail is back to length +mark+.
      # Undoing in reverse order restores every domain exactly.
      def restore(domains, trail, mark)
        while trail.length > mark
          entry = trail.pop
          if entry.first == :domain
            _type, var_name, previous = entry
            domains[var_name] = previous
          else
            _type, var_name, value, idx = entry
            domains[var_name].insert(idx, value)
          end
        end
      end

      # Generalized arc consistency for AllDifferent (Régin's filtering).
      #
      # Builds the variable/value bipartite graph and a maximum matching. It
      # then removes every (variable, value) edge that belongs to no matching
      # covering all the variables.
      class AllDifferentPropagator
        # @param domains [Hash] name => candidate values; pruned in place
        # @param trail [Array] receives one undo entry per removed value
        # @param vars [Array] names of the variables that must take distinct values
        # @return [Boolean] false if the variables cannot all take distinct values
        def gac(domains, trail, vars)
          values = vars.flat_map { |name| domains[name] }.uniq
          return false if values.empty?

          value_index = {}
          values.each_with_index { |value, idx| value_index[value] = idx }
          adj = vars.map { |name| domains[name].map { |v| value_index[v] } }

          matching = maximum_matching(adj, values.length)
          return false if matching[:size] < vars.length

          graph = build_regin_graph(adj, matching[:pair_u], values.length)
          reachable = regin_reachable_from_free_values(graph, matching[:pair_v])
          scc = tarjan_scc(graph)

          vars.each_with_index do |name, var_idx|
            domains[name].dup.each do |value|
              val_idx = value_index[value]
              next if matching[:pair_u][var_idx] == val_idx

              # An unmatched edge survives in either of two cases:
              # - it lies on an even alternating path from a free value, i.e.
              #   the edge's *value* node is reachable from a free value;
              # - it lies on an even alternating cycle, i.e. both ends are in
              #   the same SCC.
              v_node = value_node(var_idx, val_idx, vars.length)
              next if reachable[v_node] || scc[var_idx] == scc[v_node]

              remove_value(domains, name, value, trail)
              return false if domains[name].empty?
            end
          end

          true
        end

        private

        # Graph node id of a value: value nodes follow the variable nodes.
        def value_node(_var_idx, val_idx, var_count)
          var_count + val_idx
        end

        # Directed residual graph.
        #
        # Matched edges point variable -> value; unmatched edges point
        # value -> variable. Alternating paths starting at a free value are
        # then exactly the directed paths leaving it.
        def build_regin_graph(adj, pair_u, value_count)
          var_count = adj.length
          graph = Array.new(var_count + value_count) { [] }

          adj.each_with_index do |val_indices, var_idx|
            val_indices.each do |val_idx|
              v_node = value_node(var_idx, val_idx, var_count)
              if pair_u[var_idx] == val_idx
                graph[var_idx] << v_node
              else
                graph[v_node] << var_idx
              end
            end
          end

          graph
        end

        # BFS from every unmatched value.
        # @return [Array<Boolean>] per node, whether it is reachable
        def regin_reachable_from_free_values(graph, pair_v)
          var_count = graph.length - pair_v.length
          visited = Array.new(graph.length, false)
          queue = []

          pair_v.each_with_index do |matched_var, val_idx|
            next unless matched_var == -1

            node = var_count + val_idx
            visited[node] = true
            queue << node
          end

          idx = 0
          while idx < queue.length
            node = queue[idx]
            idx += 1
            graph[node].each do |nxt|
              next if visited[nxt]

              visited[nxt] = true
              queue << nxt
            end
          end

          visited
        end

        # Tarjan's strongly-connected-components algorithm.
        # Recursion depth is bounded by the number of nodes.
        # @return [Array<Integer>] SCC id per node
        def tarjan_scc(graph)
          index = 0
          stack = []
          on_stack = Array.new(graph.length, false)
          indices = Array.new(graph.length, -1)
          lowlinks = Array.new(graph.length, 0)
          scc_id = Array.new(graph.length, -1)
          current_scc = 0

          strongconnect = lambda do |v|
            indices[v] = index
            lowlinks[v] = index
            index += 1
            stack << v
            on_stack[v] = true

            graph[v].each do |w|
              if indices[w] == -1
                strongconnect.call(w)
                lowlinks[v] = [lowlinks[v], lowlinks[w]].min
              elsif on_stack[w]
                lowlinks[v] = [lowlinks[v], indices[w]].min
              end
            end

            if lowlinks[v] == indices[v]
              loop do
                w = stack.pop
                on_stack[w] = false
                scc_id[w] = current_scc
                break if w == v
              end
              current_scc += 1
            end
          end

          graph.length.times do |v|
            strongconnect.call(v) if indices[v] == -1
          end

          scc_id
        end

        # Hopcroft-Karp maximum bipartite matching.
        #
        # Variables are 0...adj.length and values are 0...value_count.
        # -1 marks an unmatched vertex.
        #
        # @return [Hash] :size, :pair_u (var -> value), :pair_v (value -> var)
        def maximum_matching(adj, value_count)
          n = adj.length
          pair_u = Array.new(n, -1)
          pair_v = Array.new(value_count, -1)
          dist = Array.new(n)

          bfs = lambda do
            queue = []
            n.times do |u|
              if pair_u[u] == -1
                dist[u] = 0
                queue << u
              else
                dist[u] = nil
              end
            end

            found = false
            idx = 0
            while idx < queue.length
              u = queue[idx]
              idx += 1
              adj[u].each do |v|
                pu = pair_v[v]
                if pu != -1 && dist[pu].nil?
                  dist[pu] = dist[u] + 1
                  queue << pu
                elsif pu == -1
                  found = true
                end
              end
            end
            found
          end

          dfs = lambda do |u|
            adj[u].each do |v|
              pu = pair_v[v]
              if pu == -1 || (dist[pu] == dist[u] + 1 && dfs.call(pu))
                pair_u[u] = v
                pair_v[v] = u
                return true
              end
            end
            dist[u] = nil
            false
          end

          while bfs.call
            n.times do |u|
              dfs.call(u) if pair_u[u] == -1
            end
          end

          size = pair_u.count { |v| v != -1 }
          { size: size, pair_u: pair_u, pair_v: pair_v }
        end

        # Deletes +value+ from a domain in place and records its index for undo.
        def remove_value(domains, var_name, value, trail)
          idx = domains[var_name].index(value)
          return if idx.nil?

          trail << [:value, var_name, value, idx]
          domains[var_name].delete_at(idx)
        end
      end
    end
  end
end
@@ -0,0 +1,125 @@
1
module Winston
  module Solvers
    # Local-search solver based on the min-conflicts heuristic.
    #
    # It starts from a complete assignment: explicit or preset values where
    # given, random domain values elsewhere. It then repeatedly moves one
    # randomly chosen conflicted, non-fixed variable to a value that minimises
    # its number of violated constraints.
    #
    # The method is incomplete: false means no solution was found within
    # max_steps, not that none exists.
    class MinConflicts
      # @param csp problem responding to #variables (name => variable),
      #   #constraints and #domain_for(name)
      # @param max_steps [Integer] maximum number of repair moves
      # @param random object responding to #rand(n) (a Random by default)
      def initialize(csp, max_steps: 10_000, random: Random.new)
        @csp = csp
        @max_steps = max_steps
        @random = random
        @constraints_for_var = build_constraints_for_var
      end

      # @param assignments [Hash] name => value for variables that must keep
      #   that value (nil values are ignored); preset variable values are also
      #   kept fixed
      # @return [Hash, false] a complete assignment satisfying every
      #   constraint, or false if none was found
      def search(assignments = {})
        @fixed_names = fixed_variables(assignments)
        current = preset_assignments(assignments)
        return false if current.nil?

        reset_conflicts(current)

        max_steps.times do
          return current if total_conflicts.zero?

          conflicted = conflicted_variables(current)
          return false if conflicted.empty?

          var = conflicted[random.rand(conflicted.size)]
          current[var.name] = best_value_for(var, current)
          update_conflicts_for(var, current)
        end

        # The final move (or an already-consistent start when max_steps is 0)
        # may have left no conflicts.
        total_conflicts.zero? ? current : false
      end

      private

      attr_reader :csp, :max_steps, :random, :fixed_names, :constraints_for_var

      # Complete starting assignment: the given value, else the variable's
      # preset value, else a random domain value.
      # @return [Hash, nil] nil if a variable needs a value but its domain is empty
      def preset_assignments(assignments)
        preset = {}
        csp.variables.each do |name, variable|
          value = assignments.key?(name) ? assignments[name] : variable.value
          if value.nil?
            domain = csp.domain_for(name)
            return nil if domain.empty?

            value = domain[random.rand(domain.size)]
          end
          preset[name] = value
        end
        preset
      end

      def total_conflicts
        @total_conflicts
      end

      # Caches whether each constraint holds and counts the violated ones.
      #
      # Keyed by object identity, so equal-but-distinct constraint objects are
      # tracked separately. A constraint listed twice is counted once, which
      # matches how update_conflicts_for adjusts the total.
      def reset_conflicts(assignments)
        @constraint_status = {}.compare_by_identity
        @total_conflicts = 0
        csp.constraints.each do |constraint|
          next if @constraint_status.key?(constraint)

          valid = satisfied?(constraint, assignments)
          @constraint_status[constraint] = valid
          @total_conflicts += 1 unless valid
        end
      end

      # Non-fixed variables involved in at least one violated constraint.
      #
      # Reads the cached statuses. update_conflicts_for keeps them current for
      # every constraint touching a changed variable, and the assignments
      # argument is kept for interface stability.
      def conflicted_variables(_assignments)
        csp.variables.values.reject do |var|
          fixed?(var) || constraints_for_var[var.name].all? { |constraint| @constraint_status[constraint] }
        end
      end

      def fixed?(var)
        fixed_names.include?(var.name)
      end

      # Number of +var+'s constraints violated under +assignments+.
      def conflicts_for(var, assignments)
        constraints_for_var[var.name].count do |constraint|
          !constraint.validate(assignments)
        end
      end

      # A value of +var+'s domain with the fewest conflicts.
      # Ties are broken with +random+.
      def best_value_for(var, assignments)
        domain = csp.domain_for(var.name)
        scored = domain.map do |value|
          candidate = assignments.merge(var.name => value)
          [value, conflicts_for(var, candidate)]
        end
        min = scored.map(&:last).min
        best = scored.select { |(_, score)| score == min }.map(&:first)
        best[random.rand(best.size)]
      end

      # Names given a non-nil value in +assignments+ or carrying a preset value.
      def fixed_variables(assignments)
        fixed = assignments.reject { |_name, value| value.nil? }.keys
        csp.variables.each_value do |var|
          fixed << var.name unless var.value.nil?
        end
        fixed.uniq
      end

      # Re-validates +var+'s constraints and adjusts the conflict total for
      # every status change.
      def update_conflicts_for(var, assignments)
        constraints_for_var[var.name].each do |constraint|
          previous = @constraint_status[constraint]
          current = satisfied?(constraint, assignments)
          next if previous == current

          @constraint_status[constraint] = current
          @total_conflicts += current ? -1 : 1
        end
      end

      # Normalises a constraint's validate result to true/false so cached
      # statuses compare reliably (nil and false must not differ).
      def satisfied?(constraint, assignments)
        constraint.validate(assignments) ? true : false
      end

      # name => every constraint touching that variable.
      # Global constraints are attached to all variables.
      def build_constraints_for_var
        map = Hash.new { |h, k| h[k] = [] }
        csp.variables.each_key { |name| map[name] = [] }

        csp.constraints.each do |constraint|
          if constraint.global
            csp.variables.each_key { |name| map[name] << constraint }
          else
            constraint.variables.each { |name| map[name] << constraint }
          end
        end

        map
      end
    end
  end

  MinConflicts = Solvers::MinConflicts
end
data/lib/winston.rb CHANGED
@@ -1,8 +1,11 @@
1
- require 'winston/domain'
2
- require 'winston/variable'
3
1
  require 'winston/constraint'
4
- require 'winston/backtrack'
5
- require 'winston/csp'
6
-
7
2
  require 'winston/constraints/all_different'
8
3
  require 'winston/constraints/not_in_list'
4
+ require 'winston/csp'
5
+ require 'winston/domain'
6
+ require 'winston/dsl'
7
+ require 'winston/heuristics'
8
+ require 'winston/solvers/backtrack'
9
+ require 'winston/solvers/mac'
10
+ require 'winston/solvers/min_conflicts'
11
+ require 'winston/variable'
@@ -0,0 +1,48 @@
1
+ require "winston"
2
+
3
# Table-driven version of the classic Australia map-colouring problem:
# regions get a three-colour domain and every shared border must differ.
describe "Map coloring example" do
  it "colors a simple map with three colors" do
    regions = %i[
      western_australia
      northern_territory
      south_australia
      queensland
      new_south_wales
      victoria
      tasmania
    ]
    borders = [
      %i[western_australia northern_territory],
      %i[western_australia south_australia],
      %i[northern_territory south_australia],
      %i[northern_territory queensland],
      %i[south_australia queensland],
      %i[south_australia new_south_wales],
      %i[south_australia victoria],
      %i[queensland new_south_wales],
      %i[new_south_wales victoria]
    ]

    csp = Winston.define do
      domain :colors, %i[red green blue]

      regions.each { |region| var region, domain: :colors }
      borders.each do |first, second|
        constraint(first, second) { |left, right| left != right }
      end
    end

    solution = csp.solve(:mac, variable_strategy: :mrv)

    expect(solution).to include(*regions)
    borders.each do |first, second|
      expect(solution[first]).to_not eq(solution[second])
    end
  end
end