ChatterBot 1.2.3__py3-none-any.whl → 1.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatterbot/__init__.py +1 -1
- chatterbot/__main__.py +15 -0
- chatterbot/chatterbot.py +55 -9
- chatterbot/comparisons.py +8 -22
- chatterbot/conversation.py +1 -2
- chatterbot/ext/django_chatterbot/abstract_models.py +16 -8
- chatterbot/ext/django_chatterbot/apps.py +7 -0
- chatterbot/ext/django_chatterbot/migrations/0020_alter_statement_conversation_and_more.py +53 -0
- chatterbot/ext/django_chatterbot/settings.py +2 -3
- chatterbot/ext/sqlalchemy_app/models.py +2 -2
- chatterbot/logic/logic_adapter.py +14 -9
- chatterbot/logic/specific_response.py +3 -7
- chatterbot/logic/time_adapter.py +3 -7
- chatterbot/preprocessors.py +4 -3
- chatterbot/response_selection.py +4 -6
- chatterbot/search.py +0 -55
- chatterbot/storage/django_storage.py +4 -1
- chatterbot/storage/sql_storage.py +1 -1
- chatterbot/tagging.py +3 -7
- chatterbot/trainers.py +306 -109
- chatterbot/utils.py +17 -28
- {chatterbot-1.2.3.dist-info → chatterbot-1.2.5.dist-info}/METADATA +8 -18
- {chatterbot-1.2.3.dist-info → chatterbot-1.2.5.dist-info}/RECORD +26 -25
- {chatterbot-1.2.3.dist-info → chatterbot-1.2.5.dist-info}/WHEEL +1 -1
- {chatterbot-1.2.3.dist-info → chatterbot-1.2.5.dist-info/licenses}/LICENSE +0 -0
- {chatterbot-1.2.3.dist-info → chatterbot-1.2.5.dist-info}/top_level.txt +0 -0
@@ -110,6 +110,9 @@ class DjangoStorageAdapter(StorageAdapter):
|
|
110
110
|
|
111
111
|
tags = kwargs.pop('tags', [])
|
112
112
|
|
113
|
+
if 'search_in_response_to' in kwargs and kwargs['search_in_response_to'] is None:
|
114
|
+
kwargs['search_in_response_to'] = ''
|
115
|
+
|
113
116
|
statement = Statement(**kwargs)
|
114
117
|
|
115
118
|
statement.save()
|
@@ -169,7 +172,7 @@ class DjangoStorageAdapter(StorageAdapter):
|
|
169
172
|
search_text=statement.search_text,
|
170
173
|
conversation=statement.conversation,
|
171
174
|
in_response_to=statement.in_response_to,
|
172
|
-
search_in_response_to=statement.search_in_response_to,
|
175
|
+
search_in_response_to=statement.search_in_response_to or '',
|
173
176
|
created_at=statement.created_at
|
174
177
|
)
|
175
178
|
|
@@ -326,7 +326,7 @@ class SQLStorageAdapter(StorageAdapter):
|
|
326
326
|
record = None
|
327
327
|
|
328
328
|
if hasattr(statement, 'id') and statement.id is not None:
|
329
|
-
record = session.
|
329
|
+
record = session.get(Statement, statement.id)
|
330
330
|
else:
|
331
331
|
record = session.query(Statement).filter(
|
332
332
|
Statement.text == statement.text,
|
chatterbot/tagging.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
|
-
from chatterbot import languages
|
1
|
+
from chatterbot import languages
|
2
|
+
from chatterbot.utils import get_model_for_language
|
2
3
|
import spacy
|
3
4
|
|
4
5
|
|
@@ -42,12 +43,7 @@ class PosLemmaTagger(object):
|
|
42
43
|
|
43
44
|
self.language = language or languages.ENG
|
44
45
|
|
45
|
-
|
46
|
-
model = constants.DEFAULT_LANGUAGE_TO_SPACY_MODEL_MAP[self.language]
|
47
|
-
except KeyError as e:
|
48
|
-
raise KeyError(
|
49
|
-
f'Spacy model is not available for language {self.language}'
|
50
|
-
) from e
|
46
|
+
model = get_model_for_language(self.language)
|
51
47
|
|
52
48
|
# Disable the Named Entity Recognition (NER) component because it is not necessary
|
53
49
|
self.nlp = spacy.load(model, exclude=['ner'])
|
chatterbot/trainers.py
CHANGED
@@ -25,7 +25,7 @@ class Trainer(object):
|
|
25
25
|
|
26
26
|
environment_default = bool(int(os.environ.get('CHATTERBOT_SHOW_TRAINING_PROGRESS', True)))
|
27
27
|
|
28
|
-
self.
|
28
|
+
self.disable_progress = not kwargs.get(
|
29
29
|
'show_training_progress',
|
30
30
|
environment_default
|
31
31
|
)
|
@@ -54,7 +54,7 @@ class Trainer(object):
|
|
54
54
|
def __init__(self, message=None):
    """
    Build the exception, substituting a generic explanation that links to
    the training documentation whenever ``message`` is empty or omitted.
    """
    if not message:
        message = (
            'A training class must be specified before calling train(). '
            'See https://docs.chatterbot.us/training/'
        )

    super().__init__(message)
|
60
60
|
|
@@ -82,7 +82,7 @@ class ListTrainer(Trainer):
|
|
82
82
|
where the list represents a conversation.
|
83
83
|
"""
|
84
84
|
|
85
|
-
def train(self, conversation):
|
85
|
+
def train(self, conversation: list):
|
86
86
|
"""
|
87
87
|
Train the chat bot based on the provided list of
|
88
88
|
statements that represents a single conversation.
|
@@ -96,7 +96,7 @@ class ListTrainer(Trainer):
|
|
96
96
|
documents = self.chatbot.tagger.as_nlp_pipeline(conversation)
|
97
97
|
|
98
98
|
# for text in enumerate(conversation):
|
99
|
-
for document in tqdm(documents, desc='List Trainer', disable=
|
99
|
+
for document in tqdm(documents, desc='List Trainer', disable=self.disable_progress):
|
100
100
|
statement_search_text = document._.search_index
|
101
101
|
|
102
102
|
statement = self.get_preprocessed_statement(
|
@@ -135,7 +135,7 @@ class ChatterBotCorpusTrainer(Trainer):
|
|
135
135
|
for corpus, categories, _file_path in tqdm(
|
136
136
|
load_corpus(*data_file_paths),
|
137
137
|
desc='ChatterBot Corpus Trainer',
|
138
|
-
disable=
|
138
|
+
disable=self.disable_progress
|
139
139
|
):
|
140
140
|
statements_to_create = []
|
141
141
|
|
@@ -172,32 +172,268 @@ class ChatterBotCorpusTrainer(Trainer):
|
|
172
172
|
self.chatbot.storage.create_many(statements_to_create)
|
173
173
|
|
174
174
|
|
175
|
-
class
|
175
|
+
class GenericFileTrainer(Trainer):
    """
    Allows the chat bot to be trained using data from a CSV or JSON file,
    or directory of those file types.
    """

    def __init__(self, chatbot, **kwargs):
        """
        :param chatbot: The chat bot instance that will be trained.
        :param dict field_map: Optional keyword argument mapping statement
            attribute names (e.g. ``text``, ``persona``) to the header / key
            (str) or the column index (int) holding that value in the data.
        """
        super().__init__(chatbot, **kwargs)

        # Subclasses set this to the extension of the files to search for
        self.file_extension = None

        # NOTE: If a value is an integer, it is used as the
        # column index instead of the key or header name
        DEFAULT_STATEMENT_TO_HEADER_MAPPING = {
            'text': 'text',
            'conversation': 'conversation',
            'created_at': 'created_at',
            'persona': 'persona',
            'tags': 'tags'
        }

        self.field_map = kwargs.get(
            'field_map',
            DEFAULT_STATEMENT_TO_HEADER_MAPPING
        )

    def _get_file_list(self, data_path, limit):
        """
        Get a list of files to read from the data set.

        :param str data_path: A single file, or a directory that is searched
            recursively for files ending in ``self.file_extension``.
        :param int limit: Maximum number of files to yield from a directory
            (``None`` means no limit). Not applied when ``data_path`` is a file.
        :raises TrainerInitializationException: If ``file_extension`` is unset
            (raised when the generator is first iterated).
        """
        if self.file_extension is None:
            raise self.TrainerInitializationException(
                'The file_extension attribute must be set before calling train().'
            )

        # List all csv or json files in the specified directory
        if os.path.isdir(data_path):
            glob_path = os.path.join(data_path, '**', f'*.{self.file_extension}')

            # Use iglob instead of glob for better performance with
            # large directories because it returns an iterator
            data_files = glob.iglob(glob_path, recursive=True)

            for index, file_path in enumerate(data_files):
                if limit is not None and index >= limit:
                    break

                yield file_path
        else:
            yield data_path

    def train(self, data_path: str, limit=None):
        """
        Train a chatbot with data from the data file.

        :param str data_path: The path to the data file or directory.
        :param int limit: The maximum number of files to train from.
        :raises TrainerInitializationException: If ``data_path`` is None or
            ``file_extension`` has not been set.
        :raises KeyError: If a ``field_map`` entry is not present in the data.
        """
        if data_path is None:
            raise self.TrainerInitializationException(
                'The data_path argument must be set to the path of a file or directory.'
            )

        data_files = self._get_file_list(data_path, limit)

        files_processed = 0

        for data_file in tqdm(data_files, desc='Training', disable=self.disable_progress):

            # Response chaining restarts for every file
            previous_statement_text = None
            previous_statement_search_text = ''

            # Choose the reader from this file's own extension so the parser
            # always matches the file actually being opened
            file_extension = data_file.split('.')[-1].lower()

            statements_to_create = []

            file_abspath = os.path.abspath(data_file)

            with open(file_abspath, 'r', encoding='utf-8') as file:

                if file_extension == 'json':
                    data = json.load(file)
                    data = data['conversation']
                elif file_extension in ('csv', 'tsv'):
                    delimiter = '\t' if file_extension == 'tsv' else ','

                    # String field map values are header names; integer
                    # values are positional column indexes
                    use_header = isinstance(next(iter(self.field_map.values())), str)

                    if use_header:
                        data = csv.DictReader(file, delimiter=delimiter)
                    else:
                        data = csv.reader(file, delimiter=delimiter)
                else:
                    # The trainer has no logger of its own; use the chat bot's
                    self.chatbot.logger.warning(f'Skipping unsupported file type: {file_extension}')
                    continue

                files_processed += 1

                text_row = self.field_map['text']

                try:
                    documents = self.chatbot.tagger.as_nlp_pipeline([
                        (
                            row[text_row],
                            {
                                # Include any defined metadata columns, keyed
                                # by statement attribute name (text is passed
                                # separately as the document itself)
                                key: row[value]
                                for key, value in self.field_map.items()
                                if key != 'text'
                            }
                        ) for row in data if len(row) > 0
                    ])
                except KeyError as e:
                    raise KeyError(
                        f'{e}. Please check the field_map parameter used to initialize '
                        f'the training class and remove this value if it is not needed. '
                        f'Current mapping: {self.field_map}'
                    ) from e

                for document, context in documents:
                    statement = Statement(
                        text=document.text,
                        conversation=context.get('conversation', 'training'),
                        persona=context.get('persona', None),
                        tags=context.get('tags', [])
                    )

                    # Blank / missing timestamps keep the statement's default
                    # instead of making the date parser fail
                    if context.get('created_at'):
                        statement.created_at = date_parser.parse(context['created_at'])

                    statement.search_text = document._.search_index
                    statement.search_in_response_to = previous_statement_search_text

                    # Use the in_response_to attribute for the previous statement if
                    # one is defined, otherwise use the last statement which was created.
                    # The context dict is keyed by attribute name, not by header.
                    if 'in_response_to' in self.field_map:
                        statement.in_response_to = context.get('in_response_to', None)
                    else:
                        statement.in_response_to = previous_statement_text

                    for preprocessor in self.chatbot.preprocessors:
                        statement = preprocessor(statement)

                    previous_statement_text = statement.text
                    previous_statement_search_text = statement.search_text

                    statements_to_create.append(statement)

            self.chatbot.storage.create_many(statements_to_create)

        if files_processed:
            self.chatbot.logger.info(
                'Training completed. {} files were read.'.format(files_processed)
            )
        else:
            self.chatbot.logger.warning(
                'No [{}] files were detected at: {}'.format(
                    self.file_extension,
                    data_path
                )
            )
|
347
|
+
|
348
|
+
|
349
|
+
class CsvFileTrainer(GenericFileTrainer):
    """
    .. note::
        Added in version 1.2.4

    Allow chatbots to be trained with data from a CSV file or
    directory of CSV files.

    TSV files are also supported, as long as the file_extension
    parameter is set to 'tsv'.

    :param str file_extension: The file extension to look for when searching for files (defaults to 'csv').
    :param dict field_map: A dictionary containing the database column name to header mapping.
        Values can be either the header name (str) or the column index (int).
    """

    def __init__(self, chatbot, **kwargs):
        super().__init__(chatbot, **kwargs)

        # Extension used to locate data files (e.g. 'csv' or 'tsv')
        self.file_extension = kwargs.get('file_extension', 'csv')
|
369
|
+
|
370
|
+
|
371
|
+
class JsonFileTrainer(GenericFileTrainer):
    """
    .. note::
        Added in version 1.2.4

    Allow chatbots to be trained with data from a JSON file or
    directory of JSON files.

    :param dict field_map: A dictionary containing the database column name to key mapping.
        By default each of ``text``, ``conversation``, ``created_at``,
        ``in_response_to``, ``persona`` and ``tags`` maps to a key of the same name.
    """

    def __init__(self, chatbot, **kwargs):
        super().__init__(chatbot, **kwargs)

        self.file_extension = 'json'

        # Unlike the CSV default, JSON data may carry an explicit
        # in_response_to value for each statement
        DEFAULT_STATEMENT_TO_KEY_MAPPING = {
            'text': 'text',
            'conversation': 'conversation',
            'created_at': 'created_at',
            'in_response_to': 'in_response_to',
            'persona': 'persona',
            'tags': 'tags'
        }

        self.field_map = kwargs.get(
            'field_map',
            DEFAULT_STATEMENT_TO_KEY_MAPPING
        )
|
400
|
+
|
401
|
+
|
402
|
+
class UbuntuCorpusTrainer(CsvFileTrainer):
|
403
|
+
"""
|
404
|
+
.. note::
|
405
|
+
PENDING DEPRECATION: Please use the ``CsvFileTrainer`` for data formats similar to this one.
|
406
|
+
|
177
407
|
Allow chatbots to be trained with the data from the Ubuntu Dialog Corpus.
|
178
408
|
|
179
409
|
For more information about the Ubuntu Dialog Corpus visit:
|
180
410
|
https://dataset.cs.mcgill.ca/ubuntu-corpus-1.0/
|
411
|
+
|
412
|
+
:param str ubuntu_corpus_data_directory: The directory where the Ubuntu corpus data is already located, or where it should be downloaded and extracted.
|
181
413
|
"""
|
182
414
|
|
183
415
|
def __init__(self, chatbot, **kwargs):
|
184
416
|
super().__init__(chatbot, **kwargs)
|
185
417
|
home_directory = os.path.expanduser('~')
|
186
418
|
|
187
|
-
self.data_download_url =
|
188
|
-
'ubuntu_corpus_data_download_url',
|
189
|
-
'http://cs.mcgill.ca/~jpineau/datasets/ubuntu-corpus-1.0/ubuntu_dialogs.tgz'
|
190
|
-
)
|
419
|
+
self.data_download_url = None
|
191
420
|
|
192
421
|
self.data_directory = kwargs.get(
|
193
422
|
'ubuntu_corpus_data_directory',
|
194
423
|
os.path.join(home_directory, 'ubuntu_data')
|
195
424
|
)
|
196
425
|
|
197
|
-
|
426
|
+
# Directory containing extracted data
|
427
|
+
self.data_path = os.path.join(
|
198
428
|
self.data_directory, 'ubuntu_dialogs'
|
199
429
|
)
|
200
430
|
|
431
|
+
self.field_map = {
|
432
|
+
'text': 3,
|
433
|
+
'created_at': 0,
|
434
|
+
'persona': 1,
|
435
|
+
}
|
436
|
+
|
201
437
|
def is_downloaded(self, file_path):
|
202
438
|
"""
|
203
439
|
Check if the data file is already downloaded.
|
@@ -222,7 +458,6 @@ class UbuntuCorpusTrainer(Trainer):
|
|
222
458
|
"""
|
223
459
|
Download a file from the given url.
|
224
460
|
Show a progress indicator for the download status.
|
225
|
-
Based on: http://stackoverflow.com/a/15645088/1547223
|
226
461
|
"""
|
227
462
|
import requests
|
228
463
|
|
@@ -238,7 +473,8 @@ class UbuntuCorpusTrainer(Trainer):
|
|
238
473
|
return file_path
|
239
474
|
|
240
475
|
with open(file_path, 'wb') as open_file:
|
241
|
-
|
476
|
+
if show_status:
|
477
|
+
print('Downloading %s' % url)
|
242
478
|
response = requests.get(url, stream=True)
|
243
479
|
total_length = response.headers.get('content-length')
|
244
480
|
|
@@ -246,136 +482,97 @@ class UbuntuCorpusTrainer(Trainer):
|
|
246
482
|
# No content length header
|
247
483
|
open_file.write(response.content)
|
248
484
|
else:
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
485
|
+
for data in tqdm(
|
486
|
+
response.iter_content(chunk_size=4096),
|
487
|
+
desc='Downloading',
|
488
|
+
disable=not show_status
|
489
|
+
):
|
253
490
|
open_file.write(data)
|
254
|
-
if show_status:
|
255
|
-
done = int(50 * download / total_length)
|
256
|
-
sys.stdout.write('\r[%s%s]' % ('=' * done, ' ' * (50 - done)))
|
257
|
-
sys.stdout.flush()
|
258
|
-
|
259
|
-
# Add a new line after the download bar
|
260
|
-
sys.stdout.write('\n')
|
261
491
|
|
262
|
-
|
492
|
+
if show_status:
|
493
|
+
print('Download location: %s' % file_path)
|
263
494
|
return file_path
|
264
495
|
|
265
496
|
def extract(self, file_path):
|
266
497
|
"""
|
267
498
|
Extract a tar file at the specified file path.
|
268
499
|
"""
|
269
|
-
|
500
|
+
if not self.disable_progress:
|
501
|
+
print('Extracting {}'.format(file_path))
|
270
502
|
|
271
|
-
if not os.path.exists(self.
|
272
|
-
os.makedirs(self.
|
503
|
+
if not os.path.exists(self.data_path):
|
504
|
+
os.makedirs(self.data_path)
|
273
505
|
|
274
|
-
def
|
275
|
-
sys.stdout.write('.')
|
276
|
-
for member in members:
|
277
|
-
# This will be the current file being extracted
|
278
|
-
yield member
|
506
|
+
def is_within_directory(directory, target):
|
279
507
|
|
280
|
-
|
281
|
-
|
508
|
+
abs_directory = os.path.abspath(directory)
|
509
|
+
abs_target = os.path.abspath(target)
|
282
510
|
|
283
|
-
|
284
|
-
abs_target = os.path.abspath(target)
|
511
|
+
prefix = os.path.commonprefix([abs_directory, abs_target])
|
285
512
|
|
286
|
-
|
513
|
+
return prefix == abs_directory
|
287
514
|
|
288
|
-
|
515
|
+
def safe_extract(tar, path='.', members=None, *, numeric_owner=False):
|
289
516
|
|
290
|
-
|
517
|
+
for member in tar.getmembers():
|
518
|
+
member_path = os.path.join(path, member.name)
|
519
|
+
if not is_within_directory(path, member_path):
|
520
|
+
raise Exception('Attempted Path Traversal in Tar File')
|
291
521
|
|
292
|
-
|
293
|
-
member_path = os.path.join(path, member.name)
|
294
|
-
if not is_within_directory(path, member_path):
|
295
|
-
raise Exception("Attempted Path Traversal in Tar File")
|
522
|
+
tar.extractall(path, members, numeric_owner=numeric_owner)
|
296
523
|
|
297
|
-
|
524
|
+
try:
|
525
|
+
with tarfile.open(file_path, 'r') as tar:
|
526
|
+
safe_extract(tar, path=self.data_path, members=tqdm(tar, disable=self.disable_progress))
|
527
|
+
except tarfile.ReadError as e:
|
528
|
+
raise self.TrainerInitializationException(
|
529
|
+
f'The provided data file is not a valid tar file: {file_path}'
|
530
|
+
) from e
|
298
531
|
|
299
|
-
|
300
|
-
|
301
|
-
self.chatbot.logger.info('File extracted to {}'.format(self.extracted_data_directory))
|
532
|
+
self.chatbot.logger.info('File extracted to {}'.format(self.data_path))
|
302
533
|
|
303
534
|
return True
|
304
|
-
|
305
|
-
def
|
535
|
+
|
536
|
+
def _get_file_list(self, data_path, limit):
|
306
537
|
"""
|
307
|
-
|
538
|
+
Get a list of files to read from the data set.
|
308
539
|
"""
|
540
|
+
|
541
|
+
if self.data_download_url is None:
|
542
|
+
raise self.TrainerInitializationException(
|
543
|
+
'The data_download_url attribute must be set before calling train().'
|
544
|
+
)
|
545
|
+
|
309
546
|
# Download and extract the Ubuntu dialog corpus if needed
|
310
547
|
corpus_download_path = self.download(self.data_download_url)
|
311
548
|
|
312
549
|
# Extract if the directory does not already exist
|
313
|
-
if not self.is_extracted(
|
550
|
+
if not self.is_extracted(data_path):
|
314
551
|
self.extract(corpus_download_path)
|
315
552
|
|
316
553
|
extracted_corpus_path = os.path.join(
|
317
|
-
|
318
|
-
'**', '**', '*.tsv'
|
554
|
+
data_path, '**', '**', '*.tsv'
|
319
555
|
)
|
320
556
|
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
yield items[start_index:end_index]
|
557
|
+
# Use iglob instead of glob for better performance with
|
558
|
+
# large directories because it returns an iterator
|
559
|
+
data_files = glob.iglob(extracted_corpus_path)
|
325
560
|
|
326
|
-
|
561
|
+
for index, file_path in enumerate(data_files):
|
562
|
+
if limit is not None and index >= limit:
|
563
|
+
break
|
327
564
|
|
328
|
-
|
329
|
-
if limit is not None:
|
330
|
-
file_list = file_list[:limit]
|
565
|
+
yield file_path
|
331
566
|
|
332
|
-
|
567
|
+
def train(self, data_download_url, limit=None):
|
568
|
+
"""
|
569
|
+
:param str data_download_url: The URL to download the Ubuntu dialog corpus from.
|
570
|
+
:param int limit: The maximum number of files to train from.
|
571
|
+
"""
|
572
|
+
self.data_download_url = data_download_url
|
333
573
|
|
334
574
|
start_time = time.time()
|
575
|
+
super().train(self.data_path, limit=limit)
|
335
576
|
|
336
|
-
|
337
|
-
|
338
|
-
statements_from_file = []
|
339
|
-
|
340
|
-
for tsv_file in tqdm(tsv_files, desc=f'Training with batch {batch_number} of {len(file_groups)}'):
|
341
|
-
with open(tsv_file, 'r', encoding='utf-8') as tsv:
|
342
|
-
reader = csv.reader(tsv, delimiter='\t')
|
343
|
-
|
344
|
-
previous_statement_text = None
|
345
|
-
previous_statement_search_text = ''
|
346
|
-
|
347
|
-
documents = self.chatbot.tagger.as_nlp_pipeline([
|
348
|
-
(
|
349
|
-
row[3],
|
350
|
-
{
|
351
|
-
'persona': row[1],
|
352
|
-
'created_at': row[0],
|
353
|
-
}
|
354
|
-
) for row in reader if len(row) > 0
|
355
|
-
])
|
356
|
-
|
357
|
-
for document, context in documents:
|
358
|
-
|
359
|
-
statement_search_text = document._.search_index
|
360
|
-
|
361
|
-
statement = Statement(
|
362
|
-
text=document.text,
|
363
|
-
in_response_to=previous_statement_text,
|
364
|
-
conversation='training',
|
365
|
-
created_at=date_parser.parse(context['created_at']),
|
366
|
-
persona=context['persona'],
|
367
|
-
search_text=statement_search_text,
|
368
|
-
search_in_response_to=previous_statement_search_text
|
369
|
-
)
|
370
|
-
|
371
|
-
for preprocessor in self.chatbot.preprocessors:
|
372
|
-
statement = preprocessor(statement)
|
373
|
-
|
374
|
-
previous_statement_text = statement.text
|
375
|
-
previous_statement_search_text = statement_search_text
|
376
|
-
|
377
|
-
statements_from_file.append(statement)
|
378
|
-
|
379
|
-
self.chatbot.storage.create_many(statements_from_file)
|
380
|
-
|
381
|
-
print('Training took', time.time() - start_time, 'seconds.')
|
577
|
+
if not self.disable_progress:
|
578
|
+
print('Training took', time.time() - start_time, 'seconds.')
|
chatterbot/utils.py
CHANGED
@@ -3,7 +3,6 @@ ChatterBot utility functions
|
|
3
3
|
"""
|
4
4
|
import importlib
|
5
5
|
import time
|
6
|
-
import sys
|
7
6
|
|
8
7
|
|
9
8
|
def import_module(dotted_path):
|
@@ -71,7 +70,7 @@ def validate_adapter_class(validate_class, adapter_class):
|
|
71
70
|
)
|
72
71
|
|
73
72
|
|
74
|
-
def get_response_time(chatbot, statement='Hello'):
|
73
|
+
def get_response_time(chatbot, statement='Hello') -> float:
|
75
74
|
"""
|
76
75
|
Returns the amount of time taken for a given
|
77
76
|
chat bot to return a response.
|
@@ -80,7 +79,6 @@ def get_response_time(chatbot, statement='Hello'):
|
|
80
79
|
:type chatbot: ChatBot
|
81
80
|
|
82
81
|
:returns: The response time in seconds.
|
83
|
-
:rtype: float
|
84
82
|
"""
|
85
83
|
start_time = time.time()
|
86
84
|
|
@@ -89,30 +87,21 @@ def get_response_time(chatbot, statement='Hello'):
|
|
89
87
|
return time.time() - start_time
|
90
88
|
|
91
89
|
|
92
|
-
def
|
90
|
+
def get_model_for_language(language):
    """
    Look up the name of the spacy model configured for ``language``.

    :param language: The language to look up; presumably one of the classes
        in ``chatterbot.languages`` (TODO confirm for other callers).
    :returns: The model name stored in ``DEFAULT_LANGUAGE_TO_SPACY_MODEL_MAP``.
    :raises KeyError: If the map has no entry for the language.
    """
    # Local import, kept inside the function as in the original
    from chatterbot import constants

    try:
        return constants.DEFAULT_LANGUAGE_TO_SPACY_MODEL_MAP[language]
    except KeyError as lookup_error:
        # Prefer a readable name when the language object provides one
        language_name = getattr(language, 'ENGLISH_NAME', language)

        raise KeyError(
            f'A corresponding spacy model for "{language_name}" could not be found.'
        ) from lookup_error
|