megahal 0.4.0

@@ -0,0 +1,482 @@
+ require 'cld'
+ require 'sooth'
+ require 'tempfile'
+ require 'json'
+ require 'zip'
+
+ class MegaHAL
+   attr_accessor :learning
+
+   # Create a new MegaHAL instance, loading the :default personality.
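+   #
+   # @example Creating an instance (illustrative; learning is on by default)
+   #   megahal = MegaHAL.new
+   #   megahal.learning #=> true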
+   def initialize
+     @learning = true
+     @seed = Sooth::Predictor.new(0)
+     @fore = Sooth::Predictor.new(0)
+     @back = Sooth::Predictor.new(0)
+     @case = Sooth::Predictor.new(0)
+     @punc = Sooth::Predictor.new(0)
+     become(:default)
+   end
+
+   def inspect
+     to_s
+   end
+
+   # Wipe MegaHAL's brain. Note that this wipes the personality too, allowing you
+   # to begin from a truly blank slate.
+   def clear
+     @seed.clear
+     @fore.clear
+     @back.clear
+     @case.clear
+     @punc.clear
+     @dictionary = { "<error>" => 0, "<fence>" => 1, "<blank>" => 2 }
+     nil
+   end
+
+   def self.add_personality(name, data)
+     @@personalities ||= {}
+     @@personalities[name.to_sym] = data.each_line.to_a
+     nil
+   end
+
+   # Returns an array of MegaHAL personalities.
+   #
+   # @return [Array] A list of symbols representing the available personalities.
+   def self.list
+     @@personalities ||= {}
+     @@personalities.keys
+   end
+
+   # Loads the specified personality. Will raise an exception if the personality
+   # parameter isn't one of those returned by #list. Note that this will clear
+   # MegaHAL's brain first.
+   #
+   # @param [Symbol] name The personality to be loaded.
+   def become(name=:default)
+     raise ArgumentError, "no such personality" unless @@personalities.key?(name)
+     clear
+     _train(@@personalities[name])
+   end
+
+   # Generate a reply to the user's input. If the learning attribute is set to
+   # true, MegaHAL will also learn from what the user said. Note that it takes
+   # MegaHAL about one second to generate about 500 replies.
+   #
+   # @param [String] input A string that represents the user's input. If this is
+   #                       nil, MegaHAL will attempt to reply with a greeting,
+   #                       suitable for beginning a conversation.
+   # @param [String] error The default reply, which will be used when no
+   #                       suitable reply can be formed.
+   #
+   # @return [String] MegaHAL's reply to the user's input, or the error
+   #                  string if no reply could be formed.
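+   #
+   # @example A hypothetical exchange (actual output depends on the brain state)
+   #   megahal = MegaHAL.new
+   #   megahal.reply(nil)      # attempts a greeting to begin a conversation
+   #   megahal.reply("Hello!") # replies, then learns from the input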
+   def reply(input, error="...")
+     puncs, norms, words = _decompose(input ? input.strip : nil)
+
+     keyword_symbols =
+       MegaHAL.extract(norms)
+              .map { |keyword| @dictionary[keyword] }
+              .compact
+
+     input_symbols = (norms || []).map { |norm| @dictionary[norm] }
+
+     # create candidate utterances
+     utterances = []
+     9.times { utterances << _generate(keyword_symbols) }
+     utterances << _generate([])
+     utterances.delete_if { |utterance| utterance == input_symbols }
+     utterances.compact!
+
+     # select the best utterance, and handle _rewrite failure
+     reply = nil
+     while reply.nil? && utterances.length > 0
+       break unless utterance = _select_utterance(utterances, keyword_symbols)
+       reply = _rewrite(utterance)
+       utterances.delete(utterance)
+     end
+
+     # learn from what the user said _after_ generating the reply
+     _learn(puncs, norms, words) if @learning && norms
+
+     return reply || error
+   end
+
+   # Save MegaHAL's brain to the specified binary file.
+   #
+   # @param [String] filename The brain file to be saved.
+   # @param [ProgressBar] bar An optional progress bar instance.
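+   #
+   # @example Saving to a hypothetical brain file
+   #   megahal.save("brain.zip")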
+   def save(filename, bar = nil)
+     bar.total = 6 unless bar.nil?
+     Zip::File.open(filename, Zip::File::CREATE) do |zipfile|
+       zipfile.get_output_stream("dictionary") do |file|
+         file.write({
+           version: 'MH10',
+           learning: @learning,
+           dictionary: @dictionary
+         }.to_json)
+       end
+       bar.increment unless bar.nil?
+       [:seed, :fore, :back, :case, :punc].each do |name|
+         tmp = _get_tmp_filename(name)
+         instance_variable_get("@#{name}").save(tmp)
+         zipfile.add(name, tmp)
+         bar.increment unless bar.nil?
+       end
+     end
+   end
+
+   # Load a brain that has previously been saved.
+   #
+   # @param [String] filename The brain file to be loaded.
+   # @param [ProgressBar] bar An optional progress bar instance.
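+   #
+   # @example Restoring a previously-saved brain (filename is hypothetical)
+   #   megahal.load("brain.zip")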
+   def load(filename, bar = nil)
+     bar.total = 6 unless bar.nil?
+     Zip::File.open(filename) do |zipfile|
+       data = JSON.parse(zipfile.find_entry("dictionary").get_input_stream.read)
+       raise "bad version" unless data['version'] == "MH10"
+       @learning = data['learning']
+       @dictionary = data['dictionary']
+       bar.increment unless bar.nil?
+       [:seed, :fore, :back, :case, :punc].each do |name|
+         tmp = _get_tmp_filename(name)
+         zipfile.find_entry(name.to_s).extract(tmp)
+         instance_variable_get("@#{name}").load(tmp)
+         bar.increment unless bar.nil?
+       end
+     end
+   end
+
+   # Train MegaHAL with the contents of the specified file, which should be
+   # plain text with one sentence per line. Note that it takes MegaHAL about
+   # one second to process about 500 lines, so large files may cause the
+   # process to block for a while. Lines that are too long will be skipped.
+   #
+   # @param [String] filename The text file to be used for training.
+   # @param [ProgressBar] bar An optional progress bar instance.
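+   #
+   # @example Training from a hypothetical corpus file
+   #   megahal.train("corpus.txt")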
+   def train(filename, bar = nil)
+     lines = File.read(filename).each_line.to_a
+     bar.total = lines.length unless bar.nil?
+     _train(lines, bar)
+   end
+
+   private
+
+   def _train(data, bar = nil)
+     data.map!(&:strip)
+     data.each do |line|
+       _learn(*_decompose(line))
+       bar.increment unless bar.nil?
+     end
+     nil
+   end
+
+   # Train each of the five models based on a sentence decomposed into a list of
+   # word separators (puncs), capitalised words (norms) and words as they were
+   # observed (in mixed case).
+   def _learn(puncs, norms, words)
+     return if words.length == 0
+
+     # Convert the three lists of strings into three lists of symbols so that we
+     # can use the Sooth::Predictor. This is done by finding the ID of each of
+     # the strings in the @dictionary, allowing us to easily rewrite each symbol
+     # back to a string later.
+     punc_symbols = puncs.map { |punc| @dictionary[punc] ||= @dictionary.length }
+     norm_symbols = norms.map { |norm| @dictionary[norm] ||= @dictionary.length }
+     word_symbols = words.map { |word| @dictionary[word] ||= @dictionary.length }
+
+     # The @seed model is used to start the forwards-backwards reply generation.
+     # Given a keyword, we want to find a word that has been observed adjacent
+     # to it. Each context here is a bigram where one symbol is the keyword and
+     # the other is the special <blank> symbol (which has ID 2). The model
+     # learns which words can fill the blank.
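+     # For example, learning the norms ["HELLO", "WORLD"] makes these
+     # observations (1 is <fence>, 2 is <blank>):
+     #   [1, 2] => HELLO       [2, HELLO] => 1
+     #   [HELLO, 2] => WORLD   [2, WORLD] => HELLO
+     #   [WORLD, 2] => 1       [2, 1] => WORLD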
+     prev = 1
+     (norm_symbols + [1]).each do |norm|
+       context = [prev, 2]
+       @seed.observe(context, norm)
+       context = [2, norm]
+       @seed.observe(context, prev)
+       prev = norm
+     end
+
+     # The @fore model is a classic second-order Markov model that can be used
+     # to generate an utterance in a random-walk fashion. For each adjacent pair
+     # of symbols the model learns which symbols can come next. Note that the
+     # special <fence> symbol (which has ID 1) is used to delimit the utterance.
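+     # For example, the norms ["HELLO", "WORLD"] yield the observations
+     #   [1, 1] => HELLO
+     #   [1, HELLO] => WORLD
+     #   [HELLO, WORLD] => 1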
+     context = [1, 1]
+     norm_symbols.each do |norm|
+       @fore.observe(context, norm)
+       context << norm
+       context.shift
+     end
+     @fore.observe(context, 1)
+
+     # The @back model is similar to the @fore model; it simply operates in the
+     # opposite direction. This is how the original MegaHAL was able to generate
+     # a random sentence guaranteed to contain a keyword; the @fore model filled
+     # in the gaps towards the end of the sentence, and the @back model filled
+     # in the gaps towards the beginning of the sentence.
+     context = [1, 1]
+     norm_symbols.reverse.each do |norm|
+       @back.observe(context, norm)
+       context << norm
+       context.shift
+     end
+     @back.observe(context, 1)
+
+     # The previous three models were all learning the sequence of norms, which
+     # are capitalised words. When we generate a reply, we want to rewrite it so
+     # MegaHAL doesn't speak in ALL CAPS. The @case model achieves this. For the
+     # previous word and the current norm it learns what the next word should be.
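+     # For example, the words ["Hello", "world"] with norms ["HELLO", "WORLD"]
+     # yield the observations
+     #   [1, HELLO] => "Hello"
+     #   ["Hello", WORLD] => "world"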
+     context = [1, 1]
+     word_symbols.zip(norm_symbols).each do |word, norm|
+       context[1] = norm
+       @case.observe(context, word)
+       context[0] = word
+     end
+
+     # After generating a list of words, we need to join them together with
+     # word-separators (whitespace and punctuation) in-between. The @punc model
+     # is used to do this; here it learns for two adjacent words which
+     # word-separators can be used to join them together.
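+     # For example, "Hello, world!" has puncs ["", ", ", "!"] and words
+     # ["Hello", "world"], yielding the observations
+     #   [1, "Hello"] => ""
+     #   ["Hello", "world"] => ", "
+     #   ["world", 1] => "!"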
+     context = [1, 1]
+     punc_symbols.zip(word_symbols + [1]).each do |punc, word|
+       context << word
+       context.shift
+       @punc.observe(context, punc)
+     end
+   end
+
+   # This takes a string and decomposes it into three arrays representing
+   # word-separators, capitalised words and the original words.
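+   # For example, "Hello, world!" decomposes to
+   #   puncs: ["", ", ", "!"]
+   #   norms: ["HELLO", "WORLD"]
+   #   words: ["Hello", "world"]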
+   def _decompose(line, maximum_length=1024)
+     return [nil, nil, nil] if line.nil?
+     line = "" if line.length > maximum_length
+     return [[], [], []] if line.length == 0
+     puncs, words = _segment(line)
+     norms = words.map(&:upcase)
+     [puncs, norms, words]
+   end
+
+   # This segments a sentence into two arrays representing word-separators and
+   # the original words themselves.
+   def _segment(line)
+     # split the sentence into an array of alternating words and word-separators
+     sequence =
+       if _character_segmentation(line)
+         line.split(/([[:word:]])/)
+       else
+         line.split(/([[:word:]]+)/)
+       end
+     # ensure the array starts and ends with a word-separator, even if it's the
+     # blank one
+     sequence << "" if sequence.last =~ /[[:word:]]+/
+     sequence.unshift("") if sequence.first =~ /[[:word:]]+/
+     # join trigrams of word-separator-word if the separator is a single ' or -
+     # this means "don't" and "hob-goblin" become single words
+     while index = sequence[1..-2].index { |item| item =~ /^['-]$/ } do
+       sequence[index+1] = sequence[index, 3].join
+       sequence[index] = nil
+       sequence[index+2] = nil
+       sequence.compact!
+     end
+     # split the alternating sequence into two arrays of word-separators and words
+     sequence.partition.with_index { |symbol, index| index.even? }
+   end
+
+   # Given an array of keyword symbols, generate an array of norms that
+   # hopefully contain at least one of the keywords. All the symbols given as
+   # keywords must have been observed in the past, otherwise this will raise an
+   # exception.
+   def _generate(keyword_symbols)
+     results =
+       if keyword = _select_keyword(keyword_symbols)
+         # Use the @seed model to find two contexts that contain the keyword.
+         contexts = [[2, keyword], [keyword, 2]]
+         contexts.map! do |context|
+           count = @seed.count(context)
+           if count > 0
+             # fill the <blank> with a random word seen adjacent to the keyword
+             limit = rand(count) + 1
+             context[context.index(2)] = @seed.select(context, limit)
+             context
+           else
+             nil
+           end
+         end
+         # Select one of the contexts at random
+         context = contexts.compact.shuffle.first
+         raise unless context
+         # Here we glue the generations of the @back and @fore models together
+         glue = context.select { |symbol| symbol != 1 }
+         _random_walk(@back, context.reverse, keyword_symbols).reverse +
+           glue + _random_walk(@fore, context, keyword_symbols)
+       else
+         # we weren't given any keywords, so do a normal markovian generation
+         context = [1, 1]
+         _random_walk(@fore, context, keyword_symbols)
+       end
+     results.length == 0 ? nil : results
+   end
+
+   # Remove auxiliary words and select at random from what remains
+   def _select_keyword(keyword_symbols)
+     (keyword_symbols - AUXILIARY.map { |word| @dictionary[word] }).shuffle.first
+   end
+
+   # This is classic Markovian generation; using a model, start with a context
+   # and continue until we hit a <fence> symbol. The only addition here is that
+   # we roll the dice several times, and prefer generations that elicit a
+   # keyword.
+   def _random_walk(model, static_context, keyword_symbols)
+     context = static_context.dup
+     results = []
+     return [] if model.count(context) == 0
+     local_keywords = keyword_symbols.dup
+     loop do
+       symbol = 0
+       10.times do
+         limit = rand(model.count(context)) + 1
+         symbol = model.select(context, limit)
+         if local_keywords.include?(symbol)
+           local_keywords.delete(symbol)
+           break
+         end
+       end
+       raise if symbol == 0
+       break if symbol == 1
+       results << symbol
+       context << symbol
+       context.shift
+     end
+     results
+   end
+
+   # Given an array of utterances and an array of keywords, select the best
+   # utterance (returning nil for none at all).
+   def _select_utterance(utterances, keyword_symbols)
+     best_score = -1
+     best_utterance = nil
+
+     utterances.each do |utterance|
+       score = _calculate_score(utterance, keyword_symbols)
+       next unless score > best_score
+       best_score = score
+       best_utterance = utterance
+     end
+
+     return best_utterance
+   end
+
+   # Calculate the score of a particular utterance as the sum of the surprise
+   # of each keyword (in both the forwards and backwards directions), with a
+   # penalty applied to discourage overly long utterances.
+   def _calculate_score(utterance, keyword_symbols)
+     score = 0
+
+     context = [1, 1]
+     utterance.each do |norm|
+       if keyword_symbols.include?(norm)
+         surprise = @fore.surprise(context, norm)
+         score += surprise unless surprise.nil?
+       end
+       context << norm
+       context.shift
+     end
+
+     context = [1, 1]
+     utterance.reverse.each do |norm|
+       if keyword_symbols.include?(norm)
+         surprise = @back.surprise(context, norm)
+         score += surprise unless surprise.nil?
+       end
+       context << norm
+       context.shift
+     end
+
+     if utterance.length >= 8
+       score /= Math.sqrt(utterance.length - 1)
+     end
+
+     if utterance.length >= 16
+       score /= utterance.length
+     end
+
+     score
+   end
+
+   # Here we take a generated sequence of norms and convert them back to a
+   # string that may be displayed to the user as output. This involves using
+   # the @case model to rewrite each norm as a word, and then using the @punc
+   # model to insert appropriate word separators.
+   def _rewrite(norm_symbols)
+     decode = Hash[@dictionary.to_a.map(&:reverse)]
+
+     # Here we generate the sequence of words. This is slightly tricky, because
+     # it is possible to generate a word (based on the context of the previous
+     # word and the current norm) such that it is impossible to generate the
+     # next word in the sequence (because we may generate a word of a different
+     # case than what we have observed in the past). So we keep trying until we
+     # stumble upon a combination that works, or until we've tried too many
+     # times. Note that backtracking would need to go back an arbitrary number
+     # of steps, and is therefore too messy to implement.
+     word_symbols = []
+     context = [1, 1]
+     i = 0
+     retries = 0
+     while word_symbols.length != norm_symbols.length
+       return nil if retries > 9
+       # We're trying to rewrite norms to words, so build a context for the
+       # @case model of the previous word and the current norm. This may fail
+       # if the previous word hasn't been observed adjacent to the current
+       # norm, which will happen if we rewrote the previous norm to a different
+       # case than what was observed previously.
+       context[0] = (i == 0) ? 1 : word_symbols[i-1]
+       context[1] = norm_symbols[i]
+       count = @case.count(context)
+       unless failed = (count == 0)
+         limit = rand(count) + 1
+         word_symbols << @case.select(context, limit)
+       end
+       if (word_symbols.length == norm_symbols.length)
+         # We need to check that the final word has been previously observed.
+         context[0] = word_symbols.last
+         context[1] = 1
+         failed = (@punc.count(context) == 0)
+       end
+       if failed
+         raise if i == 0
+         retries += 1
+         word_symbols.clear
+         i = 0
+         next
+       end
+       i += 1
+     end
+
+     # We've used the @case model to rewrite the norms to words in a way that
+     # guarantees that each adjacent pair of words has been previously
+     # observed. Now we use the @punc model to generate the word-separators to
+     # be inserted between the words in the reply.
+     punc_symbols = []
+     context = [1, 1]
+     (word_symbols + [1]).each do |word|
+       context << word
+       context.shift
+       limit = rand(@punc.count(context)) + 1
+       punc_symbols << @punc.select(context, limit)
+     end
+
+     # Finally we zip the word-separators and the words together, decode the
+     # symbols to their string representations (as stored in the @dictionary),
+     # and join everything together to give the final reply.
+     punc_symbols.zip(word_symbols).flatten.map { |word| decode[word] }.join
+   end
+
+   def _get_tmp_filename(name)
+     file = Tempfile.new(name.to_s)
+     retval = file.path
+     file.close
+     file.unlink
+     return retval
+   end
+
+   def _character_segmentation(line)
+     # use language detection to decide whether the line should be segmented
+     # into individual characters rather than whitespace-delimited words
+     language = CLD.detect_language(line)[:name]
+     ["Japanese", "Korean", "Chinese", "TG_UNKNOWN_LANGUAGE", "Unknown",
+      "JAVANESE", "THAI", "ChineseT", "LAOTHIAN", "BURMESE", "KHMER",
+      "XX"].include?(language)
+   end
+ end