megahal 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,482 @@
+ require 'cld'
+ require 'sooth'
+ require 'tempfile'
+ require 'json'
+ require 'zip'
+
+ class MegaHAL
+   attr_accessor :learning
+
+   # Create a new MegaHAL instance, loading the :default personality.
+   def initialize
+     @learning = true
+     @seed = Sooth::Predictor.new(0)
+     @fore = Sooth::Predictor.new(0)
+     @back = Sooth::Predictor.new(0)
+     @case = Sooth::Predictor.new(0)
+     @punc = Sooth::Predictor.new(0)
+     become(:default)
+   end
+
+   def inspect
+     to_s
+   end
+
+   # Wipe MegaHAL's brain. Note that this wipes the personality too, allowing
+   # you to begin from a truly blank slate.
+   def clear
+     @seed.clear
+     @fore.clear
+     @back.clear
+     @case.clear
+     @punc.clear
+     @dictionary = { "<error>" => 0, "<fence>" => 1, "<blank>" => 2 }
+     nil
+   end
+
+   def self.add_personality(name, data)
+     @@personalities ||= {}
+     @@personalities[name.to_sym] = data.each_line.to_a
+     nil
+   end
+
+   # Returns an array of MegaHAL personalities.
+   #
+   # @return [Array] A list of symbols representing the available personalities.
+   def self.list
+     @@personalities ||= {}
+     @@personalities.keys
+   end
+
+   # Loads the specified personality. Will raise an exception if the
+   # personality parameter isn't one of those returned by #list. Note that
+   # this will clear MegaHAL's brain first.
+   #
+   # @param [Symbol] name The personality to be loaded.
+   def become(name=:default)
+     raise ArgumentError, "no such personality" unless @@personalities.key?(name)
+     clear
+     _train(@@personalities[name])
+   end
+
+   # Generate a reply to the user's input. If the learning attribute is set
+   # to true, MegaHAL will also learn from what the user said. Note that it
+   # takes MegaHAL about one second to generate about 500 replies.
+   #
+   # @param [String] input A string that represents the user's input. If this
+   #                       is nil, MegaHAL will attempt to reply with a
+   #                       greeting, suitable for beginning a conversation.
+   # @param [String] error The default reply, which will be used when no
+   #                       suitable reply can be formed.
+   #
+   # @return [String] MegaHAL's reply to the user's input, or the error
+   #                  string if no reply could be formed.
+   def reply(input, error="...")
+     puncs, norms, words = _decompose(input ? input.strip : nil)
+
+     keyword_symbols =
+       MegaHAL.extract(norms)
+         .map { |keyword| @dictionary[keyword] }
+         .compact
+
+     input_symbols = (norms || []).map { |norm| @dictionary[norm] }
+
+     # create candidate utterances
+     utterances = []
+     9.times { utterances << _generate(keyword_symbols) }
+     utterances << _generate([])
+     utterances.delete_if { |utterance| utterance == input_symbols }
+     utterances.compact!
+
+     # select the best utterance, and handle _rewrite failure
+     reply = nil
+     while reply.nil? && utterances.length > 0
+       break unless utterance = _select_utterance(utterances, keyword_symbols)
+       reply = _rewrite(utterance)
+       utterances.delete(utterance)
+     end
+
+     # learn from what the user said _after_ generating the reply
+     _learn(puncs, norms, words) if @learning && norms
+
+     return reply || error
+   end
+
+   # Save MegaHAL's brain to the specified binary file.
+   #
+   # @param [String] filename The brain file to be saved.
+   # @param [ProgressBar] bar An optional progress bar instance.
+   def save(filename, bar = nil)
+     bar.total = 6 unless bar.nil?
+     Zip::File.open(filename, Zip::File::CREATE) do |zipfile|
+       zipfile.get_output_stream("dictionary") do |file|
+         file.write({
+           version: 'MH10',
+           learning: @learning,
+           dictionary: @dictionary
+         }.to_json)
+       end
+       bar.increment unless bar.nil?
+       [:seed, :fore, :back, :case, :punc].each do |name|
+         tmp = _get_tmp_filename(name)
+         instance_variable_get("@#{name}").save(tmp)
+         zipfile.add(name, tmp)
+         bar.increment unless bar.nil?
+       end
+     end
+   end
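+
+   # The archive written above contains a "dictionary" entry (JSON holding
+   # the version tag, the learning flag and the string-to-symbol map) plus
+   # one entry per Sooth model: seed, fore, back, case and punc.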
+
+   # Load a brain that has previously been saved.
+   #
+   # @param [String] filename The brain file to be loaded.
+   # @param [ProgressBar] bar An optional progress bar instance.
+   def load(filename, bar = nil)
+     bar.total = 6 unless bar.nil?
+     Zip::File.open(filename) do |zipfile|
+       data = JSON.parse(zipfile.find_entry("dictionary").get_input_stream.read)
+       raise "bad version" unless data['version'] == "MH10"
+       @learning = data['learning']
+       @dictionary = data['dictionary']
+       bar.increment unless bar.nil?
+       [:seed, :fore, :back, :case, :punc].each do |name|
+         tmp = _get_tmp_filename(name)
+         zipfile.find_entry(name.to_s).extract(tmp)
+         instance_variable_get("@#{name}").load(tmp)
+         bar.increment unless bar.nil?
+       end
+     end
+   end
+
+   # Train MegaHAL with the contents of the specified file, which should be
+   # plain text with one sentence per line. Note that it takes MegaHAL about
+   # one second to process about 500 lines, so large files may cause the
+   # process to block for a while. Lines that are too long will be skipped.
+   #
+   # @param [String] filename The text file to be used for training.
+   # @param [ProgressBar] bar An optional progress bar instance.
+   def train(filename, bar = nil)
+     lines = File.read(filename).each_line.to_a
+     bar.total = lines.length unless bar.nil?
+     _train(lines, bar)
+   end
+
+   private
+
+   def _train(data, bar = nil)
+     data.map!(&:strip)
+     data.each do |line|
+       _learn(*_decompose(line))
+       bar.increment unless bar.nil?
+     end
+     nil
+   end
+
+   # Train each of the five models based on a sentence decomposed into a list
+   # of word separators (puncs), capitalised words (norms) and words as they
+   # were observed (in mixed case).
+   def _learn(puncs, norms, words)
+     return if words.length == 0
+
+     # Convert the three lists of strings into three lists of symbols so that
+     # we can use the Sooth::Predictor. This is done by finding the ID of each
+     # of the strings in the @dictionary, allowing us to easily rewrite each
+     # symbol back to a string later.
+     punc_symbols = puncs.map { |punc| @dictionary[punc] ||= @dictionary.length }
+     norm_symbols = norms.map { |norm| @dictionary[norm] ||= @dictionary.length }
+     word_symbols = words.map { |word| @dictionary[word] ||= @dictionary.length }
+
+     # The @seed model is used to start the forwards-backwards reply
+     # generation. Given a keyword, we want to find a word that has been
+     # observed adjacent to it. Each context here is a bigram where one symbol
+     # is the keyword and the other is the special <blank> symbol (which has
+     # ID 2). The model learns which words can fill the blank.
+     prev = 1
+     (norm_symbols + [1]).each do |norm|
+       context = [prev, 2]
+       @seed.observe(context, norm)
+       context = [2, norm]
+       @seed.observe(context, prev)
+       prev = norm
+     end
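+
+     # Worked example (illustrative; suppose the norms HELLO and WORLD map to
+     # symbols 3 and 4). Learning "HELLO WORLD" makes @seed observe:
+     #   [1, 2] -> 3   and   [2, 3] -> 1   (HELLO can start an utterance)
+     #   [3, 2] -> 4   and   [2, 4] -> 3   (WORLD and HELLO are adjacent)
+     #   [4, 2] -> 1   and   [2, 1] -> 4   (WORLD can end an utterance)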
+
+     # The @fore model is a classic second-order Markov model that can be
+     # used to generate an utterance in a random-walk fashion. For each
+     # adjacent pair of symbols the model learns which symbols can come next.
+     # Note that the special <fence> symbol (which has ID 1) is used to
+     # delimit the utterance.
+     context = [1, 1]
+     norm_symbols.each do |norm|
+       @fore.observe(context, norm)
+       context << norm
+       context.shift
+     end
+     @fore.observe(context, 1)
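+
+     # Continuing the illustrative example above, @fore observes
+     # [1, 1] -> 3, [1, 3] -> 4 and [3, 4] -> 1: given the two previous
+     # norms, it learns which norm may come next.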
+
+     # The @back model is similar to the @fore model; it simply operates in
+     # the opposite direction. This is how the original MegaHAL was able to
+     # generate a random sentence guaranteed to contain a keyword; the @fore
+     # model filled in the gaps towards the end of the sentence, and the
+     # @back model filled in the gaps towards the beginning of the sentence.
+     context = [1, 1]
+     norm_symbols.reverse.each do |norm|
+       @back.observe(context, norm)
+       context << norm
+       context.shift
+     end
+     @back.observe(context, 1)
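+
+     # For the same illustrative sentence, @back observes [1, 1] -> 4,
+     # [1, 4] -> 3 and [4, 3] -> 1, mirroring @fore in reverse.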
+
+     # The previous three models were all learning the sequence of norms,
+     # which are capitalised words. When we generate a reply, we want to
+     # rewrite it so MegaHAL doesn't speak in ALL CAPS. The @case model
+     # achieves this. Given the previous word and the current norm, it learns
+     # which word the norm should become.
+     context = [1, 1]
+     word_symbols.zip(norm_symbols).each do |word, norm|
+       context[1] = norm
+       @case.observe(context, word)
+       context[0] = word
+     end
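+
+     # Illustrative example (suppose the words "Hello" and "world" map to
+     # symbols 5 and 6): @case observes [1, 3] -> 5 (HELLO at the start of
+     # an utterance was written "Hello") and [5, 4] -> 6 (WORLD following
+     # "Hello" was written "world").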
+
+     # After generating a list of words, we need to join them together with
+     # word-separators (whitespace and punctuation) in-between. The @punc
+     # model is used to do this; here it learns, for two adjacent words,
+     # which word-separators can be used to join them together.
+     context = [1, 1]
+     punc_symbols.zip(word_symbols + [1]).each do |punc, word|
+       context << word
+       context.shift
+       @punc.observe(context, punc)
+     end
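+
+     # Illustrative example (suppose the separators "", " " and "!" map to
+     # symbols 7, 8 and 9): learning "Hello world!" observes [1, 5] -> 7,
+     # [5, 6] -> 8 and [6, 1] -> 9, so "!" is learned as a way of ending an
+     # utterance after "world".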
+   end
+
+   # This takes a string and decomposes it into three arrays representing
+   # word-separators, capitalised words and the original words.
+   def _decompose(line, maximum_length=1024)
+     return [nil, nil, nil] if line.nil?
+     line = "" if line.length > maximum_length
+     return [[], [], []] if line.length == 0
+     puncs, words = _segment(line)
+     norms = words.map(&:upcase)
+     [puncs, norms, words]
+   end
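+
+   # For example (illustrative): _decompose("Hello world!") yields
+   # [["", " ", "!"], ["HELLO", "WORLD"], ["Hello", "world"]].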
+
+   # This segments a sentence into two arrays representing word-separators
+   # and the original words themselves.
+   def _segment(line)
+     # split the sentence into an array of alternating words and
+     # word-separators
+     sequence =
+       if _character_segmentation(line)
+         line.split(/([[:word:]])/)
+       else
+         line.split(/([[:word:]]+)/)
+       end
+     # ensure the array starts and ends with a word-separator, even if it's
+     # the blank one
+     sequence << "" if sequence.last =~ /[[:word:]]+/
+     sequence.unshift("") if sequence.first =~ /[[:word:]]+/
+     # join trigrams of word-separator-word if the separator is a single
+     # ' or -; this means "don't" and "hob-goblin" become single words
+     while index = sequence[1..-2].index { |item| item =~ /^['-]$/ } do
+       sequence[index+1] = sequence[index, 3].join
+       sequence[index] = nil
+       sequence[index+2] = nil
+       sequence.compact!
+     end
+     # split the alternating sequence into two arrays of word-separators and
+     # words
+     sequence.partition.with_index { |symbol, index| index.even? }
+   end
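+
+   # For example (illustrative): _segment("don't panic") first splits into
+   # ["", "don", "'", "t", " ", "panic", ""], re-joins the apostrophe trigram
+   # to give ["", "don't", " ", "panic", ""], and finally returns
+   # [["", " ", ""], ["don't", "panic"]].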
+
+   # Given an array of keyword symbols, generate an array of norms that
+   # hopefully contain at least one of the keywords. All the symbols given as
+   # keywords must have been observed in the past, otherwise this will raise
+   # an exception.
+   def _generate(keyword_symbols)
+     results =
+       if keyword = _select_keyword(keyword_symbols)
+         # Use the @seed model to find two contexts that contain the keyword.
+         contexts = [[2, keyword], [keyword, 2]]
+         contexts.map! do |context|
+           count = @seed.count(context)
+           if count > 0
+             limit = rand(count) + 1
+             context[context.index(2)] = @seed.select(context, limit)
+             context
+           else
+             nil
+           end
+         end
+         # Select one of the contexts at random.
+         context = contexts.compact.shuffle.first
+         raise unless context
+         # Glue the generations of the @back and @fore models together.
+         glue = context.select { |symbol| symbol != 1 }
+         _random_walk(@back, context.reverse, keyword_symbols).reverse +
+           glue + _random_walk(@fore, context, keyword_symbols)
+       else
+         # We weren't given any keywords, so do a normal Markovian generation.
+         context = [1, 1]
+         _random_walk(@fore, context, keyword_symbols)
+       end
+     results.length == 0 ? nil : results
+   end
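+
+   # Illustrative walk-through (reusing the symbols from the _learn example):
+   # given keyword 4, the seed context [2, 4] may become [3, 4]; @back then
+   # extends [4, 3] to the left while @fore extends [3, 4] to the right, and
+   # the result is prefix + [3, 4] + suffix, so the utterance is guaranteed
+   # to contain the keyword.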
+
+   # Remove auxiliary words and select at random from what remains.
+   def _select_keyword(keyword_symbols)
+     (keyword_symbols - AUXILIARY.map { |word| @dictionary[word] }).shuffle.first
+   end
+
+   # This is classic Markovian generation; using a model, start with a
+   # context and continue until we hit a <fence> symbol. The only addition
+   # here is that we roll the dice several times, and prefer generations that
+   # elicit a keyword.
+   def _random_walk(model, static_context, keyword_symbols)
+     context = static_context.dup
+     results = []
+     return [] if model.count(context) == 0
+     local_keywords = keyword_symbols.dup
+     loop do
+       symbol = 0
+       10.times do
+         limit = rand(model.count(context)) + 1
+         symbol = model.select(context, limit)
+         if local_keywords.include?(symbol)
+           local_keywords.delete(symbol)
+           break
+         end
+       end
+       raise if symbol == 0
+       break if symbol == 1
+       results << symbol
+       context << symbol
+       context.shift
+     end
+     results
+   end
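+
+   # Note that a draw which hits an as-yet-unused keyword is kept
+   # immediately; otherwise the tenth draw stands, so with no keywords left
+   # this degrades to plain Markovian generation.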
+
+   # Given an array of utterances and an array of keywords, select the best
+   # utterance (returning nil for none at all).
+   def _select_utterance(utterances, keyword_symbols)
+     best_score = -1
+     best_utterance = nil
+
+     utterances.each do |utterance|
+       score = _calculate_score(utterance, keyword_symbols)
+       next unless score > best_score
+       best_score = score
+       best_utterance = utterance
+     end
+
+     return best_utterance
+   end
+
+   # Calculate the score of a particular utterance.
+   def _calculate_score(utterance, keyword_symbols)
+     score = 0
+
+     context = [1, 1]
+     utterance.each do |norm|
+       if keyword_symbols.include?(norm)
+         surprise = @fore.surprise(context, norm)
+         score += surprise unless surprise.nil?
+       end
+       context << norm
+       context.shift
+     end
+
+     context = [1, 1]
+     utterance.reverse.each do |norm|
+       if keyword_symbols.include?(norm)
+         surprise = @back.surprise(context, norm)
+         score += surprise unless surprise.nil?
+       end
+       context << norm
+       context.shift
+     end
+
+     if utterance.length >= 8
+       score /= Math.sqrt(utterance.length - 1)
+     end
+
+     if utterance.length >= 16
+       score /= utterance.length
+     end
+
+     score
+   end
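+
+   # For example (illustrative): a 16-norm utterance with a raw score of 8.0
+   # ends up with 8.0 / Math.sqrt(15) / 16, roughly 0.129, so long ramblings
+   # are penalised relative to concise keyword-bearing utterances.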
+
+   # Here we take a generated sequence of norms and convert them back to a
+   # string that may be displayed to the user as output. This involves using
+   # the @case model to rewrite each norm as a word, and then using the @punc
+   # model to insert appropriate word separators.
+   def _rewrite(norm_symbols)
+     decode = Hash[@dictionary.to_a.map(&:reverse)]
+
+     # Here we generate the sequence of words. This is slightly tricky,
+     # because it is possible to generate a word (based on the context of the
+     # previous word and the current norm) such that it is impossible to
+     # generate the next word in the sequence (because we may generate a word
+     # of a different case than what we have observed in the past). So we
+     # keep trying until we stumble upon a combination that works, or until
+     # we've tried too many times. Note that backtracking would need to go
+     # back an arbitrary number of steps, and is therefore too messy to
+     # implement.
+     word_symbols = []
+     context = [1, 1]
+     i = 0
+     retries = 0
+     while word_symbols.length != norm_symbols.length
+       return nil if retries > 9
+       # We're trying to rewrite norms to words, so build a context for the
+       # @case model of the previous word and the current norm. This may fail
+       # if the previous word hasn't been observed adjacent to the current
+       # norm, which will happen if we rewrote the previous norm to a
+       # different case than what was observed previously.
+       context[0] = (i == 0) ? 1 : word_symbols[i-1]
+       context[1] = norm_symbols[i]
+       count = @case.count(context)
+       unless failed = (count == 0)
+         limit = rand(count) + 1
+         word_symbols << @case.select(context, limit)
+       end
+       if (word_symbols.length == norm_symbols.length)
+         # We need to check that the final word has been previously observed.
+         context[0] = word_symbols.last
+         context[1] = 1
+         failed = (@punc.count(context) == 0)
+       end
+       if failed
+         raise if i == 0
+         retries += 1
+         word_symbols.clear
+         i = 0
+         next
+       end
+       i += 1
+     end
+
+     # We've used the @case model to rewrite the norms to words in a way that
+     # guarantees that each adjacent pair of words has been previously
+     # observed. Now we use the @punc model to generate the word-separators
+     # to be inserted between the words in the reply.
+     punc_symbols = []
+     context = [1, 1]
+     (word_symbols + [1]).each do |word|
+       context << word
+       context.shift
+       limit = rand(@punc.count(context)) + 1
+       punc_symbols << @punc.select(context, limit)
+     end
+
+     # Finally we zip the word-separators and the words together, decode the
+     # symbols to their string representations (as stored in the @dictionary),
+     # and join everything together to give the final reply.
+     punc_symbols.zip(word_symbols).flatten.map { |word| decode[word] }.join
+   end
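+
+   # For example (illustrative, reusing the earlier symbols): separators
+   # [7, 8, 9] zipped with words [5, 6] flatten to [7, 5, 8, 6, 9, nil],
+   # which decodes to ["", "Hello", " ", "world", "!", nil] and joins to
+   # "Hello world!".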
+
+   def _get_tmp_filename(name)
+     file = Tempfile.new(name.to_s)
+     retval = file.path
+     file.close
+     file.unlink
+     return retval
+   end
+
+   def _character_segmentation(line)
+     language = CLD.detect_language(line)[:name]
+     ["Japanese", "Korean", "Chinese", "TG_UNKNOWN_LANGUAGE", "Unknown",
+      "JAVANESE", "THAI", "ChineseT", "LAOTHIAN", "BURMESE", "KHMER",
+      "XX"].include?(language)
+   end
+ end
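+
+ # A minimal usage sketch (illustrative, not part of the original source):
+ #
+ #   mh = MegaHAL.new            # loads the :default personality
+ #   mh.reply(nil)               # => a greeting to open the conversation
+ #   mh.reply("Hello, MegaHAL!") # => a reply, learning from the input
+ #   mh.save("brain.zip")        # persist the brain for a later #load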