ChatterBot 1.2.3__py3-none-any.whl → 1.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
chatterbot/storage/django_storage.py CHANGED
@@ -110,6 +110,9 @@ class DjangoStorageAdapter(StorageAdapter):
 
         tags = kwargs.pop('tags', [])
 
+        if 'search_in_response_to' in kwargs and kwargs['search_in_response_to'] is None:
+            kwargs['search_in_response_to'] = ''
+
         statement = Statement(**kwargs)
 
         statement.save()
@@ -169,7 +172,7 @@ class DjangoStorageAdapter(StorageAdapter):
             search_text=statement.search_text,
             conversation=statement.conversation,
             in_response_to=statement.in_response_to,
-            search_in_response_to=statement.search_in_response_to,
+            search_in_response_to=statement.search_in_response_to or '',
             created_at=statement.created_at
         )
 
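Both DjangoStorageAdapter hunks apply the same normalization: a null search_in_response_to value is coerced to an empty string before it reaches the ORM, which matches the Django convention of storing '' rather than NULL in string-based columns. A minimal sketch of the equivalent check, with an invented variable name, not part of the diff:

    # Hypothetical illustration of the normalization applied above:
    value = kwargs.get('search_in_response_to') or ''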
chatterbot/storage/sql_storage.py CHANGED
@@ -326,7 +326,7 @@ class SQLStorageAdapter(StorageAdapter):
         record = None
 
         if hasattr(statement, 'id') and statement.id is not None:
-            record = session.query(Statement).get(statement.id)
+            record = session.get(Statement, statement.id)
         else:
             record = session.query(Statement).filter(
                 Statement.text == statement.text,
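This change tracks the SQLAlchemy API: Query.get() was deprecated in SQLAlchemy 1.4 in favor of Session.get(), which takes the mapped class and the primary key directly. A minimal before/after sketch, assuming a configured session:

    # Legacy pattern, deprecated since SQLAlchemy 1.4:
    record = session.query(Statement).get(statement.id)

    # Replacement used by 1.2.5:
    record = session.get(Statement, statement.id)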
chatterbot/tagging.py CHANGED
@@ -1,4 +1,5 @@
-from chatterbot import languages, constants
+from chatterbot import languages
+from chatterbot.utils import get_model_for_language
 import spacy
 
 
@@ -42,12 +43,7 @@ class PosLemmaTagger(object):
 
         self.language = language or languages.ENG
 
-        try:
-            model = constants.DEFAULT_LANGUAGE_TO_SPACY_MODEL_MAP[self.language]
-        except KeyError as e:
-            raise KeyError(
-                f'Spacy model is not available for language {self.language}'
-            ) from e
+        model = get_model_for_language(self.language)
 
         # Disable the Named Entity Recognition (NER) component because it is not necessary
         self.nlp = spacy.load(model, exclude=['ner'])
chatterbot/trainers.py CHANGED
@@ -25,7 +25,7 @@ class Trainer(object):
 
         environment_default = bool(int(os.environ.get('CHATTERBOT_SHOW_TRAINING_PROGRESS', True)))
 
-        self.show_training_progress = kwargs.get(
+        self.disable_progress = not kwargs.get(
             'show_training_progress',
             environment_default
         )
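Storing the flag pre-inverted as disable_progress lets each trainer pass it straight through to tqdm's disable parameter instead of negating it at every call site, as the hunks below show. A usage sketch; the commented trainer line assumes a chatbot instance exists:

    import os

    # Environment switch read by Trainer.__init__ ('0' disables output):
    os.environ['CHATTERBOT_SHOW_TRAINING_PROGRESS'] = '0'

    # Or per trainer instance:
    # trainer = ListTrainer(chatbot, show_training_progress=False)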
@@ -54,7 +54,7 @@ class Trainer(object):
        def __init__(self, message=None):
            default = (
                'A training class must be specified before calling train(). '
-                'See https://docs.chatterbot.us/training.html'
+                'See https://docs.chatterbot.us/training/'
            )
            super().__init__(message or default)
 
@@ -82,7 +82,7 @@ class ListTrainer(Trainer):
     where the list represents a conversation.
     """
 
-    def train(self, conversation):
+    def train(self, conversation: list):
        """
        Train the chat bot based on the provided list of
        statements that represents a single conversation.
@@ -96,7 +96,7 @@ class ListTrainer(Trainer):
        documents = self.chatbot.tagger.as_nlp_pipeline(conversation)
 
        # for text in enumerate(conversation):
-        for document in tqdm(documents, desc='List Trainer', disable=not self.show_training_progress):
+        for document in tqdm(documents, desc='List Trainer', disable=self.disable_progress):
            statement_search_text = document._.search_index
 
            statement = self.get_preprocessed_statement(
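For reference, the public ListTrainer interface is unchanged apart from the new type hint; a short usage sketch with an invented conversation:

    from chatterbot import ChatBot
    from chatterbot.trainers import ListTrainer

    chatbot = ChatBot('Example')
    trainer = ListTrainer(chatbot)

    # Each statement is treated as a response to the one before it.
    trainer.train([
        'Hello',
        'Hi there!',
        'How are you?',
        'I am doing well.',
    ])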
@@ -135,7 +135,7 @@ class ChatterBotCorpusTrainer(Trainer):
        for corpus, categories, _file_path in tqdm(
            load_corpus(*data_file_paths),
            desc='ChatterBot Corpus Trainer',
-            disable=not self.show_training_progress
+            disable=self.disable_progress
        ):
            statements_to_create = []
 
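The corpus trainer call signature is likewise unchanged; only the progress flag wiring differs. A typical invocation, assuming the chatterbot-corpus data package is installed:

    from chatterbot.trainers import ChatterBotCorpusTrainer

    trainer = ChatterBotCorpusTrainer(chatbot)
    trainer.train('chatterbot.corpus.english')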
@@ -172,32 +172,268 @@ class ChatterBotCorpusTrainer(Trainer):
             self.chatbot.storage.create_many(statements_to_create)
 
 
-class UbuntuCorpusTrainer(Trainer):
+class GenericFileTrainer(Trainer):
     """
+    Allows the chat bot to be trained using data from a CSV or JSON file,
+    or directory of those file types.
+    """
+
+    def __init__(self, chatbot, **kwargs):
+        """
+        data_path: str The path to the data file or directory.
+        field_map: dict A dictionary containing the column name to header mapping.
+        """
+        super().__init__(chatbot, **kwargs)
+
+        self.file_extension = None
+
+        # NOTE: If the key is an integer, this be the
+        # column index instead of the key or header
+        DEFAULT_STATEMENT_TO_HEADER_MAPPING = {
+            'text': 'text',
+            'conversation': 'conversation',
+            'created_at': 'created_at',
+            'persona': 'persona',
+            'tags': 'tags'
+        }
+
+        self.field_map = kwargs.get(
+            'field_map',
+            DEFAULT_STATEMENT_TO_HEADER_MAPPING
+        )
+
+    def _get_file_list(self, data_path, limit):
+        """
+        Get a list of files to read from the data set.
+        """
+
+        if self.file_extension is None:
+            raise self.TrainerInitializationException(
+                'The file_extension attribute must be set before calling train().'
+            )
+
+        # List all csv or json files in the specified directory
+        if os.path.isdir(data_path):
+            glob_path = os.path.join(data_path, '**', f'*.{self.file_extension}')
+
+            # Use iglob instead of glob for better performance with
+            # large directories because it returns an iterator
+            data_files = glob.iglob(glob_path, recursive=True)
+
+            for index, file_path in enumerate(data_files):
+                if limit is not None and index >= limit:
+                    break
+
+                yield file_path
+        else:
+            yield data_path
+
+    def train(self, data_path: str, limit=None):
+        """
+        Train a chatbot with data from the data file.
+
+        :param str data_path: The path to the data file or directory.
+        :param int limit: The maximum number of files to train from.
+        """
+
+        if data_path is None:
+            raise self.TrainerInitializationException(
+                'The data_path argument must be set to the path of a file or directory.'
+            )
+
+        data_files = self._get_file_list(data_path, limit)
+
+        files_processed = 0
+
+        for data_file in tqdm(data_files, desc='Training', disable=self.disable_progress):
+
+            previous_statement_text = None
+            previous_statement_search_text = ''
+
+            file_extension = data_file.split('.')[-1].lower()
+
+            statements_to_create = []
+
+            file_abspath = os.path.abspath(data_file)
+
+            with open(file_abspath, 'r', encoding='utf-8') as file:
+
+                if self.file_extension == 'json':
+                    data = json.load(file)
+                    data = data['conversation']
+                elif file_extension == 'csv':
+                    use_header = bool(isinstance(next(iter(self.field_map.values())), str))
+
+                    if use_header:
+                        data = csv.DictReader(file)
+                    else:
+                        data = csv.reader(file)
+                elif file_extension == 'tsv':
+                    use_header = bool(isinstance(next(iter(self.field_map.values())), str))
+
+                    if use_header:
+                        data = csv.DictReader(file, delimiter='\t')
+                    else:
+                        data = csv.reader(file, delimiter='\t')
+                else:
+                    self.logger.warning(f'Skipping unsupported file type: {file_extension}')
+                    continue
+
+                files_processed += 1
+
+                text_row = self.field_map['text']
+
+                try:
+                    documents = self.chatbot.tagger.as_nlp_pipeline([
+                        (
+                            row[text_row],
+                            {
+                                # Include any defined metadata columns
+                                key: row[value]
+                                for key, value in self.field_map.items()
+                                if key != text_row
+                            }
+                        ) for row in data if len(row) > 0
+                    ])
+                except KeyError as e:
+                    raise KeyError(
+                        f'{e}. Please check the field_map parameter used to initialize '
+                        f'the training class and remove this value if it is not needed. '
+                        f'Current mapping: {self.field_map}'
+                    )
+
+                for document, context in documents:
+                    statement = Statement(
+                        text=document.text,
+                        conversation=context.get('conversation', 'training'),
+                        persona=context.get('persona', None),
+                        tags=context.get('tags', [])
+                    )
+
+                    if 'created_at' in context:
+                        statement.created_at = date_parser.parse(context['created_at'])
+
+                    statement.search_text = document._.search_index
+                    statement.search_in_response_to = previous_statement_search_text
+
+                    # Use the in_response_to attribute for the previous statement if
+                    # one is defined, otherwise use the last statement which was created
+                    if 'in_response_to' in self.field_map.keys():
+                        statement.in_response_to = context.get(self.field_map['in_response_to'], None)
+                    else:
+                        statement.in_response_to = previous_statement_text
+
+                    for preprocessor in self.chatbot.preprocessors:
+                        statement = preprocessor(statement)
+
+                    previous_statement_text = statement.text
+                    previous_statement_search_text = statement.search_text
+
+                    statements_to_create.append(statement)
+
+            self.chatbot.storage.create_many(statements_to_create)
+
+        if files_processed:
+            self.chatbot.logger.info(
+                'Training completed. {} files were read.'.format(files_processed)
+            )
+        else:
+            self.chatbot.logger.warning(
+                'No [{}] files were detected at: {}'.format(
+                    self.file_extension,
+                    data_path
+                )
+            )
+
+
+class CsvFileTrainer(GenericFileTrainer):
+    """
+    .. note::
+        Added in version 1.2.4
+
+    Allow chatbots to be trained with data from a CSV file or
+    directory of CSV files.
+
+    TSV files are also supported, as long as the file_extension
+    parameter is set to 'tsv'.
+
+    :param str file_extension: The file extension to look for when searching for files (defaults to 'csv').
+    :param str field_map: A dictionary containing the database column name to header mapping.
+        Values can be either the header name (str) or the column index (int).
+    """
+
+    def __init__(self, chatbot, **kwargs):
+        super().__init__(chatbot, **kwargs)
+
+        self.file_extension = kwargs.get('file_extension', 'csv')
+
+
+class JsonFileTrainer(GenericFileTrainer):
+    """
+    .. note::
+        Added in version 1.2.4
+
+    Allow chatbots to be trained with data from a JSON file or
+    directory of JSON files.
+
+    :param str field_map: A dictionary containing the database column name to header mapping.
+    """
+
+    def __init__(self, chatbot, **kwargs):
+        super().__init__(chatbot, **kwargs)
+
+        self.file_extension = 'json'
+
+        DEFAULT_STATEMENT_TO_KEY_MAPPING = {
+            'text': 'text',
+            'conversation': 'conversation',
+            'created_at': 'created_at',
+            'in_response_to': 'in_response_to',
+            'persona': 'persona',
+            'tags': 'tags'
+        }
+
+        self.field_map = kwargs.get(
+            'field_map',
+            DEFAULT_STATEMENT_TO_KEY_MAPPING
+        )
+
+
+class UbuntuCorpusTrainer(CsvFileTrainer):
+    """
+    .. note::
+        PENDING DEPRECATION: Please use the ``CsvFileTrainer`` for data formats similar to this one.
+
     Allow chatbots to be trained with the data from the Ubuntu Dialog Corpus.
 
     For more information about the Ubuntu Dialog Corpus visit:
     https://dataset.cs.mcgill.ca/ubuntu-corpus-1.0/
+
+    :param str ubuntu_corpus_data_directory: The directory where the Ubuntu corpus data is already located, or where it should be downloaded and extracted.
     """
 
     def __init__(self, chatbot, **kwargs):
         super().__init__(chatbot, **kwargs)
         home_directory = os.path.expanduser('~')
 
-        self.data_download_url = kwargs.get(
-            'ubuntu_corpus_data_download_url',
-            'http://cs.mcgill.ca/~jpineau/datasets/ubuntu-corpus-1.0/ubuntu_dialogs.tgz'
-        )
+        self.data_download_url = None
 
         self.data_directory = kwargs.get(
             'ubuntu_corpus_data_directory',
             os.path.join(home_directory, 'ubuntu_data')
         )
 
-        self.extracted_data_directory = os.path.join(
+        # Directory containing extracted data
+        self.data_path = os.path.join(
             self.data_directory, 'ubuntu_dialogs'
         )
 
+        self.field_map = {
+            'text': 3,
+            'created_at': 0,
+            'persona': 1,
+        }
+
     def is_downloaded(self, file_path):
         """
         Check if the data file is already downloaded.
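The new trainers are driven by the field_map: string values are treated as header names (read via csv.DictReader), while integer values are treated as positional column indexes (read via csv.reader). A hedged usage sketch; the directory layout and header names here are invented for illustration:

    from chatterbot import ChatBot
    from chatterbot.trainers import CsvFileTrainer

    chatbot = ChatBot('Example')

    # Map Statement attributes to (assumed) CSV header names.
    trainer = CsvFileTrainer(chatbot, field_map={
        'text': 'message',
        'created_at': 'sent_at',
        'persona': 'author'
    })

    # Accepts a single file, or a directory searched
    # recursively for *.csv files.
    trainer.train('./conversations/')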
@@ -222,7 +458,6 @@ class UbuntuCorpusTrainer(Trainer):
         """
         Download a file from the given url.
         Show a progress indicator for the download status.
-        Based on: http://stackoverflow.com/a/15645088/1547223
         """
         import requests
 
@@ -238,7 +473,8 @@
             return file_path
 
         with open(file_path, 'wb') as open_file:
-            print('Downloading %s' % url)
+            if show_status:
+                print('Downloading %s' % url)
             response = requests.get(url, stream=True)
             total_length = response.headers.get('content-length')
 
@@ -246,136 +482,97 @@
                 # No content length header
                 open_file.write(response.content)
             else:
-                download = 0
-                total_length = int(total_length)
-                for data in response.iter_content(chunk_size=4096):
-                    download += len(data)
+                for data in tqdm(
+                    response.iter_content(chunk_size=4096),
+                    desc='Downloading',
+                    disable=not show_status
+                ):
                     open_file.write(data)
-                    if show_status:
-                        done = int(50 * download / total_length)
-                        sys.stdout.write('\r[%s%s]' % ('=' * done, ' ' * (50 - done)))
-                        sys.stdout.flush()
-
-        # Add a new line after the download bar
-        sys.stdout.write('\n')
 
-        print('Download location: %s' % file_path)
+        if show_status:
+            print('Download location: %s' % file_path)
         return file_path
 
     def extract(self, file_path):
         """
         Extract a tar file at the specified file path.
         """
-        print('Extracting {}'.format(file_path))
+        if not self.disable_progress:
+            print('Extracting {}'.format(file_path))
 
-        if not os.path.exists(self.extracted_data_directory):
-            os.makedirs(self.extracted_data_directory)
+        if not os.path.exists(self.data_path):
+            os.makedirs(self.data_path)
 
-        def track_progress(members):
-            sys.stdout.write('.')
-            for member in members:
-                # This will be the current file being extracted
-                yield member
+        def is_within_directory(directory, target):
 
-        with tarfile.open(file_path) as tar:
-            def is_within_directory(directory, target):
+            abs_directory = os.path.abspath(directory)
+            abs_target = os.path.abspath(target)
 
-                abs_directory = os.path.abspath(directory)
-                abs_target = os.path.abspath(target)
+            prefix = os.path.commonprefix([abs_directory, abs_target])
 
-                prefix = os.path.commonprefix([abs_directory, abs_target])
+            return prefix == abs_directory
 
-                return prefix == abs_directory
+        def safe_extract(tar, path='.', members=None, *, numeric_owner=False):
 
-            def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
+            for member in tar.getmembers():
+                member_path = os.path.join(path, member.name)
+                if not is_within_directory(path, member_path):
+                    raise Exception('Attempted Path Traversal in Tar File')
 
-                for member in tar.getmembers():
-                    member_path = os.path.join(path, member.name)
-                    if not is_within_directory(path, member_path):
-                        raise Exception("Attempted Path Traversal in Tar File")
+            tar.extractall(path, members, numeric_owner=numeric_owner)
 
-                tar.extractall(path, members, numeric_owner=numeric_owner)
+        try:
+            with tarfile.open(file_path, 'r') as tar:
+                safe_extract(tar, path=self.data_path, members=tqdm(tar, disable=self.disable_progress))
+        except tarfile.ReadError as e:
+            raise self.TrainerInitializationException(
+                f'The provided data file is not a valid tar file: {file_path}'
+            ) from e
 
-            safe_extract(tar, path=self.extracted_data_directory, members=track_progress(tar))
-
-        self.chatbot.logger.info('File extracted to {}'.format(self.extracted_data_directory))
+        self.chatbot.logger.info('File extracted to {}'.format(self.data_path))
 
         return True
-
-    def train(self, limit=None):
+
+    def _get_file_list(self, data_path, limit):
         """
-        limit: int If defined, the number of files to read from the data set.
+        Get a list of files to read from the data set.
         """
+
+        if self.data_download_url is None:
+            raise self.TrainerInitializationException(
+                'The data_download_url attribute must be set before calling train().'
+            )
+
         # Download and extract the Ubuntu dialog corpus if needed
         corpus_download_path = self.download(self.data_download_url)
 
         # Extract if the directory does not already exist
-        if not self.is_extracted(self.extracted_data_directory):
+        if not self.is_extracted(data_path):
             self.extract(corpus_download_path)
 
         extracted_corpus_path = os.path.join(
-            self.extracted_data_directory,
-            '**', '**', '*.tsv'
+            data_path, '**', '**', '*.tsv'
         )
 
-        def chunks(items, items_per_chunk):
-            for start_index in range(0, len(items), items_per_chunk):
-                end_index = start_index + items_per_chunk
-                yield items[start_index:end_index]
+        # Use iglob instead of glob for better performance with
+        # large directories because it returns an iterator
+        data_files = glob.iglob(extracted_corpus_path)
 
-        file_list = glob.glob(extracted_corpus_path)
+        for index, file_path in enumerate(data_files):
+            if limit is not None and index >= limit:
+                break
 
-        # Limit the number of files used if a limit is defined
-        if limit is not None:
-            file_list = file_list[:limit]
+            yield file_path
 
-        file_groups = tuple(chunks(file_list, 5000))
+    def train(self, data_download_url, limit=None):
+        """
+        :param str data_download_url: The URL to download the Ubuntu dialog corpus from.
+        :param int limit: The maximum number of files to train from.
+        """
+        self.data_download_url = data_download_url
 
         start_time = time.time()
+        super().train(self.data_path, limit=limit)
 
-        for batch_number, tsv_files in enumerate(file_groups):
-
-            statements_from_file = []
-
-            for tsv_file in tqdm(tsv_files, desc=f'Training with batch {batch_number} of {len(file_groups)}'):
-                with open(tsv_file, 'r', encoding='utf-8') as tsv:
-                    reader = csv.reader(tsv, delimiter='\t')
-
-                    previous_statement_text = None
-                    previous_statement_search_text = ''
-
-                    documents = self.chatbot.tagger.as_nlp_pipeline([
-                        (
-                            row[3],
-                            {
-                                'persona': row[1],
-                                'created_at': row[0],
-                            }
-                        ) for row in reader if len(row) > 0
-                    ])
-
-                    for document, context in documents:
-
-                        statement_search_text = document._.search_index
-
-                        statement = Statement(
-                            text=document.text,
-                            in_response_to=previous_statement_text,
-                            conversation='training',
-                            created_at=date_parser.parse(context['created_at']),
-                            persona=context['persona'],
-                            search_text=statement_search_text,
-                            search_in_response_to=previous_statement_search_text
-                        )
-
-                        for preprocessor in self.chatbot.preprocessors:
-                            statement = preprocessor(statement)
-
-                        previous_statement_text = statement.text
-                        previous_statement_search_text = statement_search_text
-
-                        statements_from_file.append(statement)
-
-            self.chatbot.storage.create_many(statements_from_file)
-
-        print('Training took', time.time() - start_time, 'seconds.')
+        if not self.disable_progress:
+            print('Training took', time.time() - start_time, 'seconds.')
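Note the behavioral change to UbuntuCorpusTrainer: the download URL is no longer baked in as a constructor default, and train() now requires it as an argument while delegating file handling to the CsvFileTrainer machinery. A hedged sketch; the URL shown is derived from the corpus page referenced in the docstring above, not stated by the diff:

    from chatterbot.trainers import UbuntuCorpusTrainer

    trainer = UbuntuCorpusTrainer(
        chatbot,
        ubuntu_corpus_data_directory='./ubuntu_data'
    )

    # The corpus URL must now be passed explicitly (assumed value):
    trainer.train(
        'https://dataset.cs.mcgill.ca/ubuntu-corpus-1.0/ubuntu_dialogs.tgz',
        limit=100  # optional cap on the number of files read
    )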
chatterbot/utils.py CHANGED
@@ -3,7 +3,6 @@ ChatterBot utility functions
 """
 import importlib
 import time
-import sys
 
 
 def import_module(dotted_path):
@@ -71,7 +70,7 @@ def validate_adapter_class(validate_class, adapter_class):
         )
 
 
-def get_response_time(chatbot, statement='Hello'):
+def get_response_time(chatbot, statement='Hello') -> float:
    """
    Returns the amount of time taken for a given
    chat bot to return a response.
@@ -80,7 +79,6 @@ def get_response_time(chatbot, statement='Hello'):
    :type chatbot: ChatBot
 
    :returns: The response time in seconds.
-    :rtype: float
    """
    start_time = time.time()
 
@@ -89,30 +87,21 @@ def get_response_time(chatbot, statement='Hello'):
    return time.time() - start_time
 
 
-def print_progress_bar(description, iteration_counter, total_items, progress_bar_length=20):
+def get_model_for_language(language):
    """
-    Print progress bar
-    :param description: Training description
-    :type description: str
-
-    :param iteration_counter: Incremental counter
-    :type iteration_counter: int
-
-    :param total_items: total number items
-    :type total_items: int
-
-    :param progress_bar_length: Progress bar length
-    :type progress_bar_length: int
-
-    :returns: void
-    :rtype: void
-
-    DEPRECTTED: use `tqdm` instead
+    Returns the spacy model for the specified language.
    """
-    percent = float(iteration_counter) / total_items
-    hashes = '#' * int(round(percent * progress_bar_length))
-    spaces = ' ' * (progress_bar_length - len(hashes))
-    sys.stdout.write('\r{0}: [{1}] {2}%'.format(description, hashes + spaces, int(round(percent * 100))))
-    sys.stdout.flush()
-    if total_items == iteration_counter:
-        print('\r')
+    from chatterbot import constants
+
+    try:
+        model = constants.DEFAULT_LANGUAGE_TO_SPACY_MODEL_MAP[language]
+    except KeyError as e:
+        if hasattr(language, 'ENGLISH_NAME'):
+            language_name = language.ENGLISH_NAME
+        else:
+            language_name = language
+        raise KeyError(
+            f'A corresponding spacy model for "{language_name}" could not be found.'
+        ) from e
+
+    return model
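This helper centralizes the language-to-model lookup that tagging.py previously inlined (see the tagging.py hunk above), so all callers raise the same readable error. A usage sketch; the 'en_core_web_sm' result is an assumption about the mapping in chatterbot.constants, not something this diff shows:

    from chatterbot import languages
    from chatterbot.utils import get_model_for_language

    model = get_model_for_language(languages.ENG)  # e.g. 'en_core_web_sm' (assumed)

    # Unmapped languages raise KeyError using the language's
    # ENGLISH_NAME attribute when one is available.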