ChatterBot 1.2.3__py3-none-any.whl → 1.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
chatterbot/__init__.py CHANGED
@@ -4,7 +4,7 @@ ChatterBot is a machine learning, conversational dialog engine.
4
4
  from .chatterbot import ChatBot
5
5
 
6
6
 
7
- __version__ = '1.2.3'
7
+ __version__ = '1.2.4'
8
8
 
9
9
  __all__ = (
10
10
  'ChatBot',
chatterbot/__main__.py CHANGED
@@ -1,7 +1,16 @@
1
+ """
2
+ Example usage for ChatterBot command line arguments:
3
+
4
+ python -m chatterbot --help
5
+ """
6
+
1
7
  import sys
2
8
 
3
9
 
4
10
  def get_chatterbot_version():
11
+ """
12
+ Return the version of the current package.
13
+ """
5
14
  from chatterbot import __version__
6
15
 
7
16
  return __version__
@@ -10,3 +19,9 @@ def get_chatterbot_version():
10
19
  if __name__ == '__main__':
11
20
  if '--version' in sys.argv:
12
21
  print(get_chatterbot_version())
22
+ elif '--help' in sys.argv:
23
+ print('usage: chatterbot [--version, --help]')
24
+ print(' --version: Print the version of ChatterBot')
25
+ print(' --help: Print this help message')
26
+ print()
27
+ print('Documentation at https://docs.chatterbot.us')
chatterbot/chatterbot.py CHANGED
@@ -11,11 +11,41 @@ import spacy
11
11
  class ChatBot(object):
12
12
  """
13
13
  A conversational dialog chat bot.
14
+
15
+ :param name: A name is the only required parameter for the ChatBot class.
16
+ :type name: str
17
+
18
+ :keyword storage_adapter: The dot-notated import path to a storage adapter class.
19
+ Defaults to ``"chatterbot.storage.SQLStorageAdapter"``.
20
+ :type storage_adapter: str
21
+
22
+ :param logic_adapters: A list of dot-notated import paths to each logic adapter the bot uses.
23
+ Defaults to ``["chatterbot.logic.BestMatch"]``.
24
+ :type logic_adapters: list
25
+
26
+ :param tagger: The tagger to use for the chat bot.
27
+ Defaults to :class:`~chatterbot.tagging.PosLemmaTagger`
28
+ :type tagger: object
29
+
30
+ :param tagger_language: The language to use for the tagger.
31
+ Defaults to :class:`~chatterbot.languages.ENG`.
32
+ :type tagger_language: object
33
+
34
+ :param preprocessors: A list of preprocessor functions to use for the chat bot.
35
+ :type preprocessors: list
36
+
37
+ :param read_only: If True, the chat bot will not save any input it receives, defaults to False.
38
+ :type read_only: bool
39
+
40
+ :param logger: A ``Logger`` object.
41
+ :type logger: logging.Logger
14
42
  """
15
43
 
16
44
  def __init__(self, name, **kwargs):
17
45
  self.name = name
18
46
 
47
+ self.logger = kwargs.get('logger', logging.getLogger(__name__))
48
+
19
49
  storage_adapter = kwargs.get('storage_adapter', 'chatterbot.storage.SQLStorageAdapter')
20
50
 
21
51
  logic_adapters = kwargs.get('logic_adapters', [
@@ -30,11 +60,29 @@ class ChatBot(object):
30
60
 
31
61
  self.storage = utils.initialize_class(storage_adapter, **kwargs)
32
62
 
33
- Tagger = kwargs.get('tagger', PosLemmaTagger)
63
+ tagger_language = kwargs.get('tagger_language', languages.ENG)
34
64
 
35
- self.tagger = Tagger(language=kwargs.get(
36
- 'tagger_language', languages.ENG
37
- ))
65
+ try:
66
+ Tagger = kwargs.get('tagger', PosLemmaTagger)
67
+
68
+ self.tagger = Tagger(language=tagger_language)
69
+ except IOError as io_error:
70
+ # Return a more helpful error message if possible
71
+ if "Can't find model" in str(io_error):
72
+ model_name = utils.get_model_for_language(tagger_language)
73
+ if hasattr(tagger_language, 'ENGLISH_NAME'):
74
+ language_name = tagger_language.ENGLISH_NAME
75
+ else:
76
+ language_name = tagger_language
77
+ raise self.ChatBotException(
78
+ 'Setup error:\n'
79
+ f'The Spacy model for "{language_name}" language is missing.\n'
80
+ 'Please install the model using the command:\n\n'
81
+ f'python -m spacy download {model_name}\n\n'
82
+ 'See https://spacy.io/usage/models for more information about available models.'
83
+ ) from io_error
84
+ else:
85
+ raise io_error
38
86
 
39
87
  primary_search_algorithm = IndexedTextSearch(self, **kwargs)
40
88
  text_search_algorithm = TextSearch(self, **kwargs)
@@ -63,8 +111,6 @@ class ChatBot(object):
63
111
  # NOTE: 'xx' is the language code for a multi-language model
64
112
  self.nlp = spacy.blank(self.tagger.language.ISO_639_1)
65
113
 
66
- self.logger = kwargs.get('logger', logging.getLogger(__name__))
67
-
68
114
  # Allow the bot to save input it receives so that it can learn
69
115
  self.read_only = kwargs.get('read_only', False)
70
116
 
chatterbot/comparisons.py CHANGED
@@ -2,7 +2,7 @@
2
2
  This module contains various text-comparison algorithms
3
3
  designed to compare one statement to another.
4
4
  """
5
- from chatterbot import constants
5
+ from chatterbot.utils import get_model_for_language
6
6
  from difflib import SequenceMatcher
7
7
  import spacy
8
8
 
@@ -100,12 +100,7 @@ class SpacySimilarity(Comparator):
100
100
  def __init__(self, language):
101
101
  super().__init__(language)
102
102
 
103
- try:
104
- model = constants.DEFAULT_LANGUAGE_TO_SPACY_MODEL_MAP[self.language]
105
- except KeyError as e:
106
- raise KeyError(
107
- f'Spacy model is not available for language {self.language}'
108
- ) from e
103
+ model = get_model_for_language(language)
109
104
 
110
105
  # Disable the Named Entity Recognition (NER) component because it is not necessary
111
106
  self.nlp = spacy.load(model, exclude=['ner'])
@@ -157,12 +152,7 @@ class JaccardSimilarity(Comparator):
157
152
  def __init__(self, language):
158
153
  super().__init__(language)
159
154
 
160
- try:
161
- model = constants.DEFAULT_LANGUAGE_TO_SPACY_MODEL_MAP[self.language]
162
- except KeyError as e:
163
- raise KeyError(
164
- f'Spacy model is not available for language {self.language}'
165
- ) from e
155
+ model = get_model_for_language(language)
166
156
 
167
157
  # Disable the Named Entity Recognition (NER) component because it is not necessary
168
158
  self.nlp = spacy.load(model, exclude=['ner'])
@@ -1,7 +1,7 @@
1
1
  from sqlalchemy import Table, Column, Integer, String, DateTime, ForeignKey
2
- from sqlalchemy.orm import relationship
2
+ from sqlalchemy.orm import relationship, declarative_base
3
3
  from sqlalchemy.sql import func
4
- from sqlalchemy.ext.declarative import declared_attr, declarative_base
4
+ from sqlalchemy.ext.declarative import declared_attr
5
5
 
6
6
  from chatterbot.conversation import StatementMixin
7
7
  from chatterbot import constants
@@ -1,6 +1,7 @@
1
1
  from chatterbot.logic import LogicAdapter
2
2
  from chatterbot.conversation import Statement
3
- from chatterbot import constants, languages
3
+ from chatterbot import languages
4
+ from chatterbot.utils import get_model_for_language
4
5
  import spacy
5
6
 
6
7
 
@@ -35,12 +36,7 @@ class SpecificResponseAdapter(LogicAdapter):
35
36
  self._output_text = kwargs.get('output_text')
36
37
 
37
38
  def _initialize_nlp(self, language):
38
- try:
39
- model = constants.DEFAULT_LANGUAGE_TO_SPACY_MODEL_MAP[language]
40
- except KeyError as e:
41
- raise KeyError(
42
- f'Spacy model is not available for language {language}'
43
- ) from e
39
+ model = get_model_for_language(language)
44
40
 
45
41
  return spacy.load(model)
46
42
 
@@ -1,7 +1,8 @@
1
1
  from datetime import datetime
2
- from chatterbot import constants, languages
2
+ from chatterbot import languages
3
3
  from chatterbot.logic import LogicAdapter
4
4
  from chatterbot.conversation import Statement
5
+ from chatterbot.utils import get_model_for_language
5
6
  import spacy
6
7
 
7
8
 
@@ -36,12 +37,7 @@ class TimeLogicAdapter(LogicAdapter):
36
37
 
37
38
  language = kwargs.get('language', languages.ENG)
38
39
 
39
- try:
40
- model = constants.DEFAULT_LANGUAGE_TO_SPACY_MODEL_MAP[language]
41
- except KeyError as e:
42
- raise KeyError(
43
- f'Spacy model is not available for language {language}'
44
- ) from e
40
+ model = get_model_for_language(language)
45
41
 
46
42
  self.nlp = spacy.load(model)
47
43
 
chatterbot/search.py CHANGED
@@ -149,58 +149,3 @@ class TextSearch:
149
149
  ))
150
150
 
151
151
  yield statement
152
-
153
-
154
- class VectorSearch:
155
- """
156
- .. note:: BETA feature: this search method is new and experimental.
157
-
158
- Search for similar text based on a :term:`vector database`.
159
- """
160
-
161
- name = 'vector_search'
162
-
163
- def __init__(self, chatbot, **kwargs):
164
- from chatterbot.storage import RedisVectorStorageAdapter
165
-
166
- # Good documentation:
167
- # https://python.langchain.com/docs/integrations/vectorstores/redis/
168
- #
169
- # https://hub.docker.com/r/redis/redis-stack
170
-
171
- # Mondodb:
172
- # > Vector Search is only supported on Atlas Clusters
173
- # https://www.mongodb.com/community/forums/t/can-a-local-mongodb-instance-be-used-when-working-with-langchain-mongodbatlasvectorsearch/265356
174
-
175
- # FAISS:
176
- # https://python.langchain.com/docs/integrations/vectorstores/faiss/
177
-
178
- print("Starting Redis Vector Store")
179
-
180
- # TODO: look into:
181
- # https://python.langchain.com/api_reference/redis/chat_message_history/langchain_redis.chat_message_history.RedisChatMessageHistory.html
182
-
183
- # The VectorSearch class is only compatible with the RedisVectorStorageAdapter
184
- if not isinstance(chatbot.storage, RedisVectorStorageAdapter):
185
- raise Exception(
186
- 'The VectorSearch search method requires the RedisVectorStorageAdapter storage adapter.'
187
- )
188
-
189
- def search(self, input_statement, **additional_parameters):
190
- print("Querying Vector Store")
191
-
192
- # Similarity search with score and filter
193
- # NOTE: It looks like `return_all` is needed to return the full document
194
- # specifically what we need here is the ID
195
- scored_results = self.storage.vector_store.similarity_search_with_score(
196
- input_statement.text, k=2, return_all=True
197
- )
198
- # sort_by="score", filter={"category": "likes"})
199
-
200
- print("Similarity Search with Score Results:\n")
201
- for doc, score in scored_results:
202
- print(f"Content: {doc.page_content[:150]}...")
203
- print(f"ID: {doc.id}")
204
- print(f"Metadata: {doc.metadata}")
205
- print(f"Score: {score}")
206
- print()
@@ -326,7 +326,7 @@ class SQLStorageAdapter(StorageAdapter):
326
326
  record = None
327
327
 
328
328
  if hasattr(statement, 'id') and statement.id is not None:
329
- record = session.query(Statement).get(statement.id)
329
+ record = session.get(Statement, statement.id)
330
330
  else:
331
331
  record = session.query(Statement).filter(
332
332
  Statement.text == statement.text,
chatterbot/tagging.py CHANGED
@@ -1,4 +1,5 @@
1
- from chatterbot import languages, constants
1
+ from chatterbot import languages
2
+ from chatterbot.utils import get_model_for_language
2
3
  import spacy
3
4
 
4
5
 
@@ -42,12 +43,7 @@ class PosLemmaTagger(object):
42
43
 
43
44
  self.language = language or languages.ENG
44
45
 
45
- try:
46
- model = constants.DEFAULT_LANGUAGE_TO_SPACY_MODEL_MAP[self.language]
47
- except KeyError as e:
48
- raise KeyError(
49
- f'Spacy model is not available for language {self.language}'
50
- ) from e
46
+ model = get_model_for_language(self.language)
51
47
 
52
48
  # Disable the Named Entity Recognition (NER) component because it is not necessary
53
49
  self.nlp = spacy.load(model, exclude=['ner'])
chatterbot/trainers.py CHANGED
@@ -25,7 +25,7 @@ class Trainer(object):
25
25
 
26
26
  environment_default = bool(int(os.environ.get('CHATTERBOT_SHOW_TRAINING_PROGRESS', True)))
27
27
 
28
- self.show_training_progress = kwargs.get(
28
+ self.disable_progress = not kwargs.get(
29
29
  'show_training_progress',
30
30
  environment_default
31
31
  )
@@ -54,7 +54,7 @@ class Trainer(object):
54
54
  def __init__(self, message=None):
55
55
  default = (
56
56
  'A training class must be specified before calling train(). '
57
- 'See https://docs.chatterbot.us/training.html'
57
+ 'See https://docs.chatterbot.us/training/'
58
58
  )
59
59
  super().__init__(message or default)
60
60
 
@@ -82,7 +82,7 @@ class ListTrainer(Trainer):
82
82
  where the list represents a conversation.
83
83
  """
84
84
 
85
- def train(self, conversation):
85
+ def train(self, conversation: list):
86
86
  """
87
87
  Train the chat bot based on the provided list of
88
88
  statements that represents a single conversation.
@@ -96,7 +96,7 @@ class ListTrainer(Trainer):
96
96
  documents = self.chatbot.tagger.as_nlp_pipeline(conversation)
97
97
 
98
98
  # for text in enumerate(conversation):
99
- for document in tqdm(documents, desc='List Trainer', disable=not self.show_training_progress):
99
+ for document in tqdm(documents, desc='List Trainer', disable=self.disable_progress):
100
100
  statement_search_text = document._.search_index
101
101
 
102
102
  statement = self.get_preprocessed_statement(
@@ -135,7 +135,7 @@ class ChatterBotCorpusTrainer(Trainer):
135
135
  for corpus, categories, _file_path in tqdm(
136
136
  load_corpus(*data_file_paths),
137
137
  desc='ChatterBot Corpus Trainer',
138
- disable=not self.show_training_progress
138
+ disable=self.disable_progress
139
139
  ):
140
140
  statements_to_create = []
141
141
 
@@ -172,32 +172,259 @@ class ChatterBotCorpusTrainer(Trainer):
172
172
  self.chatbot.storage.create_many(statements_to_create)
173
173
 
174
174
 
175
- class UbuntuCorpusTrainer(Trainer):
175
+ class GenericFileTrainer(Trainer):
176
176
  """
177
+ Allows the chat bot to be trained using data from a CSV or JSON file,
178
+ or directory of those file types.
179
+ """
180
+
181
+ def __init__(self, chatbot, **kwargs):
182
+ """
183
+ data_path: str The path to the data file or directory.
184
+ field_map: dict A dictionary containing the column name to header mapping.
185
+ """
186
+ super().__init__(chatbot, **kwargs)
187
+
188
+ self.file_extension = None
189
+
190
+ # NOTE: If the key is an integer, this be the
191
+ # column index instead of the key or header
192
+ DEFAULT_STATEMENT_TO_HEADER_MAPPING = {
193
+ 'text': 'text',
194
+ 'conversation': 'conversation',
195
+ 'created_at': 'created_at',
196
+ 'persona': 'persona',
197
+ 'tags': 'tags'
198
+ }
199
+
200
+ self.field_map = kwargs.get(
201
+ 'field_map',
202
+ DEFAULT_STATEMENT_TO_HEADER_MAPPING
203
+ )
204
+
205
+ def _get_file_list(self, data_path, limit):
206
+ """
207
+ Get a list of files to read from the data set.
208
+ """
209
+
210
+ if self.file_extension is None:
211
+ raise self.TrainerInitializationException(
212
+ 'The file_extension attribute must be set before calling train().'
213
+ )
214
+
215
+ # List all csv or json files in the specified directory
216
+ if os.path.isdir(data_path):
217
+ glob_path = os.path.join(data_path, '**', f'*.{self.file_extension}')
218
+
219
+ # Use iglob instead of glob for better performance with
220
+ # large directories because it returns an iterator
221
+ data_files = glob.iglob(glob_path, recursive=True)
222
+
223
+ for index, file_path in enumerate(data_files):
224
+ if limit is not None and index >= limit:
225
+ break
226
+
227
+ yield file_path
228
+ else:
229
+ return [data_path]
230
+
231
+ def train(self, data_path: str, limit=None):
232
+ """
233
+ Train a chatbot with data from the data file.
234
+
235
+ :param str data_path: The path to the data file or directory.
236
+ :param int limit: The maximum number of files to train from.
237
+ """
238
+
239
+ if data_path is None:
240
+ raise self.TrainerInitializationException(
241
+ 'The data_path argument must be set to the path of a file or directory.'
242
+ )
243
+
244
+ data_files = self._get_file_list(data_path, limit)
245
+
246
+ files_processed = 0
247
+
248
+ for data_file in tqdm(data_files, desc='Training', disable=self.disable_progress):
249
+
250
+ previous_statement_text = None
251
+ previous_statement_search_text = ''
252
+
253
+ file_extension = data_file.split('.')[-1].lower()
254
+
255
+ statements_to_create = []
256
+
257
+ with open(data_file, 'r', encoding='utf-8') as file:
258
+
259
+ if self.file_extension == 'json':
260
+ data = json.load(file)
261
+ data = data['conversation']
262
+ elif file_extension == 'csv':
263
+ use_header = bool(isinstance(next(iter(self.field_map.values())), str))
264
+
265
+ if use_header:
266
+ data = csv.DictReader(file)
267
+ else:
268
+ data = csv.reader(file)
269
+ elif file_extension == 'tsv':
270
+ use_header = bool(isinstance(next(iter(self.field_map.values())), str))
271
+
272
+ if use_header:
273
+ data = csv.DictReader(file, delimiter='\t')
274
+ else:
275
+ data = csv.reader(file, delimiter='\t')
276
+ else:
277
+ self.logger.warning(f'Skipping unsupported file type: {file_extension}')
278
+ continue
279
+
280
+ files_processed += 1
281
+
282
+ text_row = self.field_map['text']
283
+
284
+ documents = self.chatbot.tagger.as_nlp_pipeline([
285
+ (
286
+ row[text_row],
287
+ {
288
+ # Include any defined metadata columns
289
+ key: row[value]
290
+ for key, value in self.field_map.items()
291
+ if key != text_row
292
+ }
293
+ ) for row in data if len(row) > 0
294
+ ])
295
+
296
+ for document, context in documents:
297
+ statement = Statement(
298
+ text=document.text,
299
+ conversation=context.get('conversation', 'training'),
300
+ persona=context.get('persona', None),
301
+ tags=context.get('tags', [])
302
+ )
303
+
304
+ if 'created_at' in context:
305
+ statement.created_at = date_parser.parse(context['created_at'])
306
+
307
+ statement.search_text = document._.search_index
308
+ statement.search_in_response_to = previous_statement_search_text
309
+
310
+ # Use the in_response_to attribute for the previous statement if
311
+ # one is defined, otherwise use the last statement which was created
312
+ if 'in_response_to' in self.field_map.keys():
313
+ statement.in_response_to = context.get(self.field_map['in_response_to'], None)
314
+ else:
315
+ statement.in_response_to = previous_statement_text
316
+
317
+ for preprocessor in self.chatbot.preprocessors:
318
+ statement = preprocessor(statement)
319
+
320
+ previous_statement_text = statement.text
321
+ previous_statement_search_text = statement.search_text
322
+
323
+ statements_to_create.append(statement)
324
+
325
+ self.chatbot.storage.create_many(statements_to_create)
326
+
327
+ if files_processed:
328
+ self.chatbot.logger.info(
329
+ 'Training completed. {} files were read.'.format(files_processed)
330
+ )
331
+ else:
332
+ self.chatbot.logger.warning(
333
+ 'No [{}] files were detected at: {}'.format(
334
+ self.file_extension,
335
+ data_path
336
+ )
337
+ )
338
+
339
+
340
+ class CsvFileTrainer(GenericFileTrainer):
341
+ """
342
+ .. note::
343
+ Added in version 1.2.4
344
+
345
+ Allow chatbots to be trained with data from a CSV file or
346
+ directory of CSV files.
347
+
348
+ TSV files are also supported, as long as the file_extension
349
+ parameter is set to 'tsv'.
350
+
351
+ :param str file_extension: The file extension to look for when searching for files (defaults to 'csv').
352
+ :param str field_map: A dictionary containing the database column name to header mapping.
353
+ Values can be either the header name (str) or the column index (int).
354
+ """
355
+
356
+ def __init__(self, chatbot, **kwargs):
357
+ super().__init__(chatbot, **kwargs)
358
+
359
+ self.file_extension = kwargs.get('file_extension', 'csv')
360
+
361
+
362
+ class JsonFileTrainer(GenericFileTrainer):
363
+ """
364
+ .. note::
365
+ Added in version 1.2.4
366
+
367
+ Allow chatbots to be trained with data from a JSON file or
368
+ directory of JSON files.
369
+
370
+ :param str field_map: A dictionary containing the database column name to header mapping.
371
+ """
372
+
373
+ def __init__(self, chatbot, **kwargs):
374
+ super().__init__(chatbot, **kwargs)
375
+
376
+ self.file_extension = 'json'
377
+
378
+ DEFAULT_STATEMENT_TO_KEY_MAPPING = {
379
+ 'text': 'text',
380
+ 'conversation': 'conversation',
381
+ 'created_at': 'created_at',
382
+ 'in_response_to': 'in_response_to',
383
+ 'persona': 'persona',
384
+ 'tags': 'tags'
385
+ }
386
+
387
+ self.field_map = kwargs.get(
388
+ 'field_map',
389
+ DEFAULT_STATEMENT_TO_KEY_MAPPING
390
+ )
391
+
392
+
393
+ class UbuntuCorpusTrainer(CsvFileTrainer):
394
+ """
395
+ .. note::
396
+ PENDING DEPRECATION: Please use the ``CsvFileTrainer`` for data formats similar to this one.
397
+
177
398
  Allow chatbots to be trained with the data from the Ubuntu Dialog Corpus.
178
399
 
179
400
  For more information about the Ubuntu Dialog Corpus visit:
180
401
  https://dataset.cs.mcgill.ca/ubuntu-corpus-1.0/
402
+
403
+ :param str ubuntu_corpus_data_directory: The directory where the Ubuntu corpus data is already located, or where it should be downloaded and extracted.
181
404
  """
182
405
 
183
406
  def __init__(self, chatbot, **kwargs):
184
407
  super().__init__(chatbot, **kwargs)
185
408
  home_directory = os.path.expanduser('~')
186
409
 
187
- self.data_download_url = kwargs.get(
188
- 'ubuntu_corpus_data_download_url',
189
- 'http://cs.mcgill.ca/~jpineau/datasets/ubuntu-corpus-1.0/ubuntu_dialogs.tgz'
190
- )
410
+ self.data_download_url = None
191
411
 
192
412
  self.data_directory = kwargs.get(
193
413
  'ubuntu_corpus_data_directory',
194
414
  os.path.join(home_directory, 'ubuntu_data')
195
415
  )
196
416
 
197
- self.extracted_data_directory = os.path.join(
417
+ # Directory containing extracted data
418
+ self.data_path = os.path.join(
198
419
  self.data_directory, 'ubuntu_dialogs'
199
420
  )
200
421
 
422
+ self.field_map = {
423
+ 'text': 3,
424
+ 'created_at': 0,
425
+ 'persona': 1,
426
+ }
427
+
201
428
  def is_downloaded(self, file_path):
202
429
  """
203
430
  Check if the data file is already downloaded.
@@ -222,7 +449,6 @@ class UbuntuCorpusTrainer(Trainer):
222
449
  """
223
450
  Download a file from the given url.
224
451
  Show a progress indicator for the download status.
225
- Based on: http://stackoverflow.com/a/15645088/1547223
226
452
  """
227
453
  import requests
228
454
 
@@ -238,7 +464,8 @@ class UbuntuCorpusTrainer(Trainer):
238
464
  return file_path
239
465
 
240
466
  with open(file_path, 'wb') as open_file:
241
- print('Downloading %s' % url)
467
+ if show_status:
468
+ print('Downloading %s' % url)
242
469
  response = requests.get(url, stream=True)
243
470
  total_length = response.headers.get('content-length')
244
471
 
@@ -246,136 +473,97 @@ class UbuntuCorpusTrainer(Trainer):
246
473
  # No content length header
247
474
  open_file.write(response.content)
248
475
  else:
249
- download = 0
250
- total_length = int(total_length)
251
- for data in response.iter_content(chunk_size=4096):
252
- download += len(data)
476
+ for data in tqdm(
477
+ response.iter_content(chunk_size=4096),
478
+ desc='Downloading',
479
+ disable=not show_status
480
+ ):
253
481
  open_file.write(data)
254
- if show_status:
255
- done = int(50 * download / total_length)
256
- sys.stdout.write('\r[%s%s]' % ('=' * done, ' ' * (50 - done)))
257
- sys.stdout.flush()
258
482
 
259
- # Add a new line after the download bar
260
- sys.stdout.write('\n')
261
-
262
- print('Download location: %s' % file_path)
483
+ if show_status:
484
+ print('Download location: %s' % file_path)
263
485
  return file_path
264
486
 
265
487
  def extract(self, file_path):
266
488
  """
267
489
  Extract a tar file at the specified file path.
268
490
  """
269
- print('Extracting {}'.format(file_path))
270
-
271
- if not os.path.exists(self.extracted_data_directory):
272
- os.makedirs(self.extracted_data_directory)
491
+ if not self.disable_progress:
492
+ print('Extracting {}'.format(file_path))
273
493
 
274
- def track_progress(members):
275
- sys.stdout.write('.')
276
- for member in members:
277
- # This will be the current file being extracted
278
- yield member
494
+ if not os.path.exists(self.data_path):
495
+ os.makedirs(self.data_path)
279
496
 
280
- with tarfile.open(file_path) as tar:
281
- def is_within_directory(directory, target):
497
+ def is_within_directory(directory, target):
282
498
 
283
- abs_directory = os.path.abspath(directory)
284
- abs_target = os.path.abspath(target)
499
+ abs_directory = os.path.abspath(directory)
500
+ abs_target = os.path.abspath(target)
285
501
 
286
- prefix = os.path.commonprefix([abs_directory, abs_target])
502
+ prefix = os.path.commonprefix([abs_directory, abs_target])
287
503
 
288
- return prefix == abs_directory
504
+ return prefix == abs_directory
289
505
 
290
- def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
506
+ def safe_extract(tar, path='.', members=None, *, numeric_owner=False):
291
507
 
292
- for member in tar.getmembers():
293
- member_path = os.path.join(path, member.name)
294
- if not is_within_directory(path, member_path):
295
- raise Exception("Attempted Path Traversal in Tar File")
508
+ for member in tar.getmembers():
509
+ member_path = os.path.join(path, member.name)
510
+ if not is_within_directory(path, member_path):
511
+ raise Exception('Attempted Path Traversal in Tar File')
296
512
 
297
- tar.extractall(path, members, numeric_owner=numeric_owner)
513
+ tar.extractall(path, members, numeric_owner=numeric_owner)
298
514
 
299
- safe_extract(tar, path=self.extracted_data_directory, members=track_progress(tar))
515
+ try:
516
+ with tarfile.open(file_path, 'r') as tar:
517
+ safe_extract(tar, path=self.data_path, members=tqdm(tar, disable=self.disable_progress))
518
+ except tarfile.ReadError as e:
519
+ raise self.TrainerInitializationException(
520
+ f'The provided data file is not a valid tar file: {file_path}'
521
+ ) from e
300
522
 
301
- self.chatbot.logger.info('File extracted to {}'.format(self.extracted_data_directory))
523
+ self.chatbot.logger.info('File extracted to {}'.format(self.data_path))
302
524
 
303
525
  return True
304
-
305
- def train(self, limit=None):
526
+
527
+ def _get_file_list(self, data_path, limit):
306
528
  """
307
- limit: int If defined, the number of files to read from the data set.
529
+ Get a list of files to read from the data set.
308
530
  """
531
+
532
+ if self.data_download_url is None:
533
+ raise self.TrainerInitializationException(
534
+ 'The data_download_url attribute must be set before calling train().'
535
+ )
536
+
309
537
  # Download and extract the Ubuntu dialog corpus if needed
310
538
  corpus_download_path = self.download(self.data_download_url)
311
539
 
312
540
  # Extract if the directory does not already exist
313
- if not self.is_extracted(self.extracted_data_directory):
541
+ if not self.is_extracted(data_path):
314
542
  self.extract(corpus_download_path)
315
543
 
316
544
  extracted_corpus_path = os.path.join(
317
- self.extracted_data_directory,
318
- '**', '**', '*.tsv'
545
+ data_path, '**', '**', '*.tsv'
319
546
  )
320
547
 
321
- def chunks(items, items_per_chunk):
322
- for start_index in range(0, len(items), items_per_chunk):
323
- end_index = start_index + items_per_chunk
324
- yield items[start_index:end_index]
548
+ # Use iglob instead of glob for better performance with
549
+ # large directories because it returns an iterator
550
+ data_files = glob.iglob(extracted_corpus_path)
325
551
 
326
- file_list = glob.glob(extracted_corpus_path)
552
+ for index, file_path in enumerate(data_files):
553
+ if limit is not None and index >= limit:
554
+ break
327
555
 
328
- # Limit the number of files used if a limit is defined
329
- if limit is not None:
330
- file_list = file_list[:limit]
556
+ yield file_path
331
557
 
332
- file_groups = tuple(chunks(file_list, 5000))
558
+ def train(self, data_download_url, limit=None):
559
+ """
560
+ :param str data_download_url: The URL to download the Ubuntu dialog corpus from.
561
+ :param int limit: The maximum number of files to train from.
562
+ """
563
+ self.data_download_url = data_download_url
333
564
 
334
565
  start_time = time.time()
566
+ super().train(self.data_path, limit=limit)
335
567
 
336
- for batch_number, tsv_files in enumerate(file_groups):
337
-
338
- statements_from_file = []
339
-
340
- for tsv_file in tqdm(tsv_files, desc=f'Training with batch {batch_number} of {len(file_groups)}'):
341
- with open(tsv_file, 'r', encoding='utf-8') as tsv:
342
- reader = csv.reader(tsv, delimiter='\t')
343
-
344
- previous_statement_text = None
345
- previous_statement_search_text = ''
346
-
347
- documents = self.chatbot.tagger.as_nlp_pipeline([
348
- (
349
- row[3],
350
- {
351
- 'persona': row[1],
352
- 'created_at': row[0],
353
- }
354
- ) for row in reader if len(row) > 0
355
- ])
356
-
357
- for document, context in documents:
358
-
359
- statement_search_text = document._.search_index
360
-
361
- statement = Statement(
362
- text=document.text,
363
- in_response_to=previous_statement_text,
364
- conversation='training',
365
- created_at=date_parser.parse(context['created_at']),
366
- persona=context['persona'],
367
- search_text=statement_search_text,
368
- search_in_response_to=previous_statement_search_text
369
- )
370
-
371
- for preprocessor in self.chatbot.preprocessors:
372
- statement = preprocessor(statement)
373
-
374
- previous_statement_text = statement.text
375
- previous_statement_search_text = statement_search_text
376
-
377
- statements_from_file.append(statement)
378
-
379
- self.chatbot.storage.create_many(statements_from_file)
380
-
381
- print('Training took', time.time() - start_time, 'seconds.')
568
+ if not self.disable_progress:
569
+ print('Training took', time.time() - start_time, 'seconds.')
chatterbot/utils.py CHANGED
@@ -89,30 +89,21 @@ def get_response_time(chatbot, statement='Hello'):
89
89
  return time.time() - start_time
90
90
 
91
91
 
92
- def print_progress_bar(description, iteration_counter, total_items, progress_bar_length=20):
92
+ def get_model_for_language(language):
93
93
  """
94
- Print progress bar
95
- :param description: Training description
96
- :type description: str
97
-
98
- :param iteration_counter: Incremental counter
99
- :type iteration_counter: int
100
-
101
- :param total_items: total number items
102
- :type total_items: int
103
-
104
- :param progress_bar_length: Progress bar length
105
- :type progress_bar_length: int
106
-
107
- :returns: void
108
- :rtype: void
109
-
110
- DEPRECTTED: use `tqdm` instead
94
+ Returns the spacy model for the specified language.
111
95
  """
112
- percent = float(iteration_counter) / total_items
113
- hashes = '#' * int(round(percent * progress_bar_length))
114
- spaces = ' ' * (progress_bar_length - len(hashes))
115
- sys.stdout.write('\r{0}: [{1}] {2}%'.format(description, hashes + spaces, int(round(percent * 100))))
116
- sys.stdout.flush()
117
- if total_items == iteration_counter:
118
- print('\r')
96
+ from chatterbot import constants
97
+
98
+ try:
99
+ model = constants.DEFAULT_LANGUAGE_TO_SPACY_MODEL_MAP[language]
100
+ except KeyError as e:
101
+ if hasattr(language, 'ENGLISH_NAME'):
102
+ language_name = language.ENGLISH_NAME
103
+ else:
104
+ language_name = language
105
+ raise KeyError(
106
+ f'A corresponding spacy model for "{language_name}" could not be found.'
107
+ ) from e
108
+
109
+ return model
@@ -1,27 +1,15 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: ChatterBot
3
- Version: 1.2.3
3
+ Version: 1.2.4
4
4
  Summary: ChatterBot is a machine learning, conversational dialog engine
5
5
  Author: Gunther Cox
6
- License: Copyright (c) 2016 - 2025, Gunther Cox
7
- All rights reserved.
8
-
9
- Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
10
-
11
- * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
12
-
13
- * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
14
-
15
- * Neither the name of ChatterBot nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
16
-
17
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
18
-
6
+ License-Expression: BSD-3-Clause
19
7
  Project-URL: Documentation, https://docs.chatterbot.us
20
8
  Project-URL: Repository, https://github.com/gunthercox/ChatterBot
9
+ Project-URL: Changelog, https://github.com/gunthercox/ChatterBot/releases
21
10
  Keywords: ChatterBot,chatbot,chat,bot,natural language processing,nlp,artificial intelligence,ai
22
11
  Classifier: Development Status :: 4 - Beta
23
12
  Classifier: Intended Audience :: Developers
24
- Classifier: License :: OSI Approved :: BSD License
25
13
  Classifier: Operating System :: OS Independent
26
14
  Classifier: Environment :: Console
27
15
  Classifier: Environment :: Web Environment
@@ -40,7 +28,7 @@ Classifier: Programming Language :: Python
40
28
  Classifier: Programming Language :: Python :: 3
41
29
  Classifier: Programming Language :: Python :: 3.9
42
30
  Classifier: Programming Language :: Python :: 3 :: Only
43
- Requires-Python: ~=3.9
31
+ Requires-Python: <3.13,>=3.9
44
32
  Description-Content-Type: text/markdown
45
33
  License-File: LICENSE
46
34
  Requires-Dist: mathparse<0.2,>=0.1
@@ -54,6 +42,7 @@ Requires-Dist: coverage; extra == "test"
54
42
  Requires-Dist: nose; extra == "test"
55
43
  Requires-Dist: sphinx<8.2,>=5.3; extra == "test"
56
44
  Requires-Dist: sphinx-sitemap>=2.6.0; extra == "test"
45
+ Requires-Dist: huggingface_hub; extra == "test"
57
46
  Provides-Extra: dev
58
47
  Requires-Dist: pint>=0.8.1; extra == "dev"
59
48
  Requires-Dist: pyyaml<7.0,>=6.0; extra == "dev"
@@ -65,6 +54,7 @@ Requires-Dist: langchain-huggingface; extra == "redis"
65
54
  Requires-Dist: sentence-transformers; extra == "redis"
66
55
  Provides-Extra: mongodb
67
56
  Requires-Dist: pymongo<4.12,>=4.11; extra == "mongodb"
57
+ Dynamic: license-file
68
58
 
69
59
  ![ChatterBot: Machine learning in Python](https://i.imgur.com/b3SCmGT.png)
70
60
 
@@ -167,7 +157,7 @@ See release notes for changes https://github.com/gunthercox/ChatterBot/releases
167
157
  a new branch `my-pull-request`.
168
158
  3. [Create a pull request](https://help.github.com/articles/creating-a-pull-request/).
169
159
  4. Please follow the [Python style guide for PEP-8](https://www.python.org/dev/peps/pep-0008/).
170
- 5. Use the projects [built-in automated testing](https://docs.chatterbot.us/testing.html).
160
+ 5. Use the projects [built-in automated testing](https://docs.chatterbot.us/testing/).
171
161
  to help make sure that your contribution is free from errors.
172
162
 
173
163
  # License
@@ -1,8 +1,8 @@
1
- chatterbot/__init__.py,sha256=PhC2oXazQN3HNYXvWb33IAqEzwyt5QtqcfESq8eg3sg,158
2
- chatterbot/__main__.py,sha256=nk19D56TlPT9Zdqkq4qZZrOnLKEc4YTwUVWmXYwSyHg,207
1
+ chatterbot/__init__.py,sha256=Woq2bFnaAs8yTE2HVPsxXEyzFrXs1njsGJnJVgbYGvI,158
2
+ chatterbot/__main__.py,sha256=zvH4uxtGlGrP-ht_LkhX29duzjm3hRH800SDCq4YOwg,637
3
3
  chatterbot/adapters.py,sha256=LJ_KqLpHKPdYAFpMGK63RVH4weV5X0Zh5uGyan6qdVU,878
4
- chatterbot/chatterbot.py,sha256=YLKLkQ-XI4Unr3rbzjpGIupOqenuevm21tAnx-yFFgQ,10400
5
- chatterbot/comparisons.py,sha256=8-qLFWC1Z7tZ3iPUpyY6AD9l-whSo3QE1Rno_SzIp-I,6570
4
+ chatterbot/chatterbot.py,sha256=BW_XQK78iOvc0fZ8EsEglNUdjyRE2lxUI_sP-fa4gCc,12505
5
+ chatterbot/comparisons.py,sha256=8kkjW-lhS-57XSUlQI5B-dAdJO-CvkIirWLBKtbe4gw,6187
6
6
  chatterbot/components.py,sha256=ld3Xam8olBClvE5QqcFYggE7Q7tODCFek7BO7lhfyeU,1782
7
7
  chatterbot/constants.py,sha256=c_KPQKc82CHX6H3maeyTYqWatx6j-N-8HJhmejoVi60,1875
8
8
  chatterbot/conversation.py,sha256=Y-WOxPN7I3igRyAEe5py1sfS6JIYPdbwjVlY3kM8Ys8,3175
@@ -13,10 +13,10 @@ chatterbot/languages.py,sha256=XSenfc5FxHk_JWG5gGHsZvjvrPBbCaVCm_OU-BeER_M,32784
13
13
  chatterbot/parsing.py,sha256=vS-w70cMkjq4YEpDOv_pXWhAI6Zj06WYDAcMDhYDj0M,23174
14
14
  chatterbot/preprocessors.py,sha256=aI4v987dZc7GOKhO43i0i73EX748hehYSpzikFHpEXs,1271
15
15
  chatterbot/response_selection.py,sha256=aYeZ54jpGIcQnI-1-TDcua_f1p3PiM5_iMg4hF5ZaIU,2951
16
- chatterbot/search.py,sha256=S4MFL1JzPqT-pv7tCgd-MIqf0T9Ia_KOLoNgzdoCP4Y,7035
17
- chatterbot/tagging.py,sha256=GLY9wg_rvn6pSYVML-HcxkIo_3BZ3SAyj-q1oNZY8pI,2584
18
- chatterbot/trainers.py,sha256=U1yh0_V7FFL51MeQe1P1Q59weceDbkHh_2kDiDYpSEc,13315
19
- chatterbot/utils.py,sha256=ckQXvsjp2FO9GcWxziY67JovN7mShnE4RlzdYarQY5k,3277
16
+ chatterbot/search.py,sha256=FTwwON2eKPWqoc5uoKh4AUmuXDCqyfMcMcXB4wijpxg,4910
17
+ chatterbot/tagging.py,sha256=czcI2g18vILujphkjvobRyEewJU8-QjS7QRzY-hCZ4o,2429
18
+ chatterbot/trainers.py,sha256=9mxi1_UmtiuXuEzpn4uztnV8PObD0Xt0PrAbTZ6oyt0,19294
19
+ chatterbot/utils.py,sha256=tGmUt-KYYylD2fiG_oq_XxhGbAHukzwudZ_6hNuraIA,2944
20
20
  chatterbot/vectorstores.py,sha256=-S1NB8PrZzoFIu95n2W7N4UaXuCUpyDUXIGYFebjv08,2056
21
21
  chatterbot/ext/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
22
  chatterbot/ext/django_chatterbot/__init__.py,sha256=iWzmBzpAsYwkwi1faxAPFY9L1bbL97RgVXK2uqULIMc,92
@@ -47,22 +47,22 @@ chatterbot/ext/django_chatterbot/migrations/0018_text_max_length.py,sha256=508Tx
47
47
  chatterbot/ext/django_chatterbot/migrations/0019_alter_statement_id_alter_tag_id_and_more.py,sha256=rsVxwDFMQ-cU1KMhjDq9Wcl_6gTPKc_dc3p-gv_R7v8,999
48
48
  chatterbot/ext/django_chatterbot/migrations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
49
49
  chatterbot/ext/sqlalchemy_app/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
50
- chatterbot/ext/sqlalchemy_app/models.py,sha256=pjU4e2BUSitw_IAkrk4iFQ9pZRU35y5MomvX7aiBFCw,2492
50
+ chatterbot/ext/sqlalchemy_app/models.py,sha256=ZQ-R_5rA-f1agaqYGUQhuuO7zx__BvTDUvJo5R7ZrDY,2492
51
51
  chatterbot/logic/__init__.py,sha256=28-5swBCPfSVMl8xB5C8frOKZ2oj28rQfenbd9E4r-4,531
52
52
  chatterbot/logic/best_match.py,sha256=8TNW0uZ_Uq-XPfaZUMUZDVH6KzDT65j59xblxQBv-dQ,4820
53
53
  chatterbot/logic/logic_adapter.py,sha256=5kNEirh5fiF5hhSMFXD7bIkKwXHmrSsSS4qDm-6xry0,4694
54
54
  chatterbot/logic/mathematical_evaluation.py,sha256=GPDKUwNFajERof2R-MkPGi2jJRP-rKAGm_f0V9JHDHE,2282
55
- chatterbot/logic/specific_response.py,sha256=o17YIeu9DzucO8MXMP3kwNIBb1b8br60bbAhSE7AZWc,2386
56
- chatterbot/logic/time_adapter.py,sha256=mxdoQGeC5IjREH4PU5iHYOIPEvnYnzgysocR8xMYWXc,2406
55
+ chatterbot/logic/specific_response.py,sha256=akWHkfe0AjzlCUvjs_PbKFNkX4SZhu_tzY45xCRXoo0,2236
56
+ chatterbot/logic/time_adapter.py,sha256=1PT6tWtGauZLRH02-Xlh2LublDpu_3hnCqHBqNGM9yg,2256
57
57
  chatterbot/logic/unit_conversion.py,sha256=-ENMLqZqtZx0riUi0guda2oJECST0M7pZG4cSIv3ieM,5898
58
58
  chatterbot/storage/__init__.py,sha256=ADw0WQe0YKr1UIDQLaxwf0mHDnuKW_CSzgz11K4TM-4,465
59
59
  chatterbot/storage/django_storage.py,sha256=S5S4GipD7FyNJy4RWu5-S8sLPuSJIObwTtqTpnJu-ok,6159
60
60
  chatterbot/storage/mongodb.py,sha256=Ozvdvcjb3LGZxcvbSQGzwP9VloYQbmsa2FaKunFpMyU,7934
61
61
  chatterbot/storage/redis.py,sha256=FKROrzZ-7WXZ8ZoK0dKmTDdS45TxL04XOSeu0p3Jrak,12675
62
- chatterbot/storage/sql_storage.py,sha256=VVYZvclG_74IN-MrG0edc-RQ2gUO6gRQyCWWSO0MmCk,13082
62
+ chatterbot/storage/sql_storage.py,sha256=dAMLByFKQgbiTFoBUtKDeqadYRdwVO5fz1OONTcVCH4,13076
63
63
  chatterbot/storage/storage_adapter.py,sha256=fvyb-qNiB0HMJ0siVMCWUIY--6d-C47N1_kKZVFZAv4,6110
64
- chatterbot-1.2.3.dist-info/LICENSE,sha256=5b04U8mi0wp5gJMYlKi49EalnD9Q2nwY_6UEI_Avgu4,1476
65
- chatterbot-1.2.3.dist-info/METADATA,sha256=xnofLrmf6knmhcwBVcodzvxpZQ-eb4tbLB970dXQG8I,8503
66
- chatterbot-1.2.3.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
67
- chatterbot-1.2.3.dist-info/top_level.txt,sha256=W2TzAbAJ-eBXTIKZZhVlkrh87msJNmBQpyhkrHqjSrE,11
68
- chatterbot-1.2.3.dist-info/RECORD,,
64
+ chatterbot-1.2.4.dist-info/licenses/LICENSE,sha256=5b04U8mi0wp5gJMYlKi49EalnD9Q2nwY_6UEI_Avgu4,1476
65
+ chatterbot-1.2.4.dist-info/METADATA,sha256=lrGa5gxvPrNRh6fCKqr7zPvRp_qmY293ijj0ODW4uZM,7049
66
+ chatterbot-1.2.4.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
67
+ chatterbot-1.2.4.dist-info/top_level.txt,sha256=W2TzAbAJ-eBXTIKZZhVlkrh87msJNmBQpyhkrHqjSrE,11
68
+ chatterbot-1.2.4.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (76.0.0)
2
+ Generator: setuptools (78.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5