ChatterBot 1.2.2 → 1.2.4 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatterbot/__init__.py +1 -1
- chatterbot/__main__.py +15 -0
- chatterbot/chatterbot.py +52 -6
- chatterbot/comparisons.py +3 -13
- chatterbot/ext/sqlalchemy_app/models.py +2 -2
- chatterbot/logic/specific_response.py +48 -9
- chatterbot/logic/time_adapter.py +3 -7
- chatterbot/logic/unit_conversion.py +4 -3
- chatterbot/storage/__init__.py +2 -0
- chatterbot/storage/redis.py +390 -0
- chatterbot/storage/sql_storage.py +1 -1
- chatterbot/tagging.py +3 -7
- chatterbot/trainers.py +297 -109
- chatterbot/utils.py +16 -25
- chatterbot/vectorstores.py +74 -0
- {chatterbot-1.2.2.dist-info → chatterbot-1.2.4.dist-info}/METADATA +16 -20
- {chatterbot-1.2.2.dist-info → chatterbot-1.2.4.dist-info}/RECORD +20 -18
- {chatterbot-1.2.2.dist-info → chatterbot-1.2.4.dist-info}/WHEEL +1 -1
- {chatterbot-1.2.2.dist-info → chatterbot-1.2.4.dist-info/licenses}/LICENSE +0 -0
- {chatterbot-1.2.2.dist-info → chatterbot-1.2.4.dist-info}/top_level.txt +0 -0
chatterbot/storage/sql_storage.py
CHANGED

```diff
@@ -326,7 +326,7 @@ class SQLStorageAdapter(StorageAdapter):
         record = None

         if hasattr(statement, 'id') and statement.id is not None:
-            record = session.
+            record = session.get(Statement, statement.id)
         else:
             record = session.query(Statement).filter(
                 Statement.text == statement.text,
```
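For context on the storage change: this tracks SQLAlchemy's 2.0-style API, where `Session.get()` performs the primary-key lookup previously handled by the legacy `Query.get()`. A minimal sketch of the new call, assuming the `Statement` model from `chatterbot.ext.sqlalchemy_app.models` and an illustrative database URL:

```python
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

from chatterbot.ext.sqlalchemy_app.models import Statement

engine = create_engine('sqlite:///db.sqlite3')  # illustrative database URL

with Session(engine) as session:
    # 2.0-style primary-key lookup: returns the matching row or None.
    # Replaces the legacy session.query(Statement).get(statement_id) form,
    # which is deprecated in SQLAlchemy 2.0.
    record = session.get(Statement, 1)
```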
chatterbot/tagging.py
CHANGED

```diff
@@ -1,4 +1,5 @@
-from chatterbot import languages
+from chatterbot import languages
+from chatterbot.utils import get_model_for_language
 import spacy


@@ -42,12 +43,7 @@ class PosLemmaTagger(object):

         self.language = language or languages.ENG

-
-            model = constants.DEFAULT_LANGUAGE_TO_SPACY_MODEL_MAP[self.language]
-        except KeyError as e:
-            raise KeyError(
-                f'Spacy model is not available for language {self.language}'
-            ) from e
+        model = get_model_for_language(self.language)

         # Disable the Named Entity Recognition (NER) component because it is not necessary
         self.nlp = spacy.load(model, exclude=['ner'])
```
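Separate from the lookup refactor, note the `exclude=['ner']` argument retained in the last line: spaCy skips loading excluded pipeline components entirely. A small sketch of the effect, using `en_core_web_sm` as an assumed example model:

```python
import spacy

# Excluded components are never loaded, which saves startup time and
# memory; this tagger only needs part-of-speech tags and lemmas.
nlp = spacy.load('en_core_web_sm', exclude=['ner'])

doc = nlp('ChatterBot tags statements before storing them.')
print([(token.text, token.pos_, token.lemma_) for token in doc])
```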
chatterbot/trainers.py
CHANGED

```diff
@@ -25,7 +25,7 @@ class Trainer(object):

         environment_default = bool(int(os.environ.get('CHATTERBOT_SHOW_TRAINING_PROGRESS', True)))

-        self.
+        self.disable_progress = not kwargs.get(
            'show_training_progress',
            environment_default
        )
@@ -54,7 +54,7 @@ class Trainer(object):
     def __init__(self, message=None):
         default = (
             'A training class must be specified before calling train(). '
-            'See https://docs.chatterbot.us/training
+            'See https://docs.chatterbot.us/training/'
         )
         super().__init__(message or default)

@@ -82,7 +82,7 @@ class ListTrainer(Trainer):
     where the list represents a conversation.
     """

-    def train(self, conversation):
+    def train(self, conversation: list):
         """
         Train the chat bot based on the provided list of
         statements that represents a single conversation.
@@ -96,7 +96,7 @@ class ListTrainer(Trainer):
         documents = self.chatbot.tagger.as_nlp_pipeline(conversation)

         # for text in enumerate(conversation):
-        for document in tqdm(documents, desc='List Trainer', disable=
+        for document in tqdm(documents, desc='List Trainer', disable=self.disable_progress):
             statement_search_text = document._.search_index

             statement = self.get_preprocessed_statement(
@@ -135,7 +135,7 @@ class ChatterBotCorpusTrainer(Trainer):
         for corpus, categories, _file_path in tqdm(
             load_corpus(*data_file_paths),
             desc='ChatterBot Corpus Trainer',
-            disable=
+            disable=self.disable_progress
         ):
             statements_to_create = []

```
```diff
@@ -172,32 +172,259 @@ class ChatterBotCorpusTrainer(Trainer):
         self.chatbot.storage.create_many(statements_to_create)


-class
+class GenericFileTrainer(Trainer):
     """
+    Allows the chat bot to be trained using data from a CSV or JSON file,
+    or directory of those file types.
+    """
+
+    def __init__(self, chatbot, **kwargs):
+        """
+        data_path: str The path to the data file or directory.
+        field_map: dict A dictionary containing the column name to header mapping.
+        """
+        super().__init__(chatbot, **kwargs)
+
+        self.file_extension = None
+
+        # NOTE: If the key is an integer, this be the
+        # column index instead of the key or header
+        DEFAULT_STATEMENT_TO_HEADER_MAPPING = {
+            'text': 'text',
+            'conversation': 'conversation',
+            'created_at': 'created_at',
+            'persona': 'persona',
+            'tags': 'tags'
+        }
+
+        self.field_map = kwargs.get(
+            'field_map',
+            DEFAULT_STATEMENT_TO_HEADER_MAPPING
+        )
+
+    def _get_file_list(self, data_path, limit):
+        """
+        Get a list of files to read from the data set.
+        """
+
+        if self.file_extension is None:
+            raise self.TrainerInitializationException(
+                'The file_extension attribute must be set before calling train().'
+            )
+
+        # List all csv or json files in the specified directory
+        if os.path.isdir(data_path):
+            glob_path = os.path.join(data_path, '**', f'*.{self.file_extension}')
+
+            # Use iglob instead of glob for better performance with
+            # large directories because it returns an iterator
+            data_files = glob.iglob(glob_path, recursive=True)
+
+            for index, file_path in enumerate(data_files):
+                if limit is not None and index >= limit:
+                    break
+
+                yield file_path
+        else:
+            return [data_path]
+
+    def train(self, data_path: str, limit=None):
+        """
+        Train a chatbot with data from the data file.
+
+        :param str data_path: The path to the data file or directory.
+        :param int limit: The maximum number of files to train from.
+        """
+
+        if data_path is None:
+            raise self.TrainerInitializationException(
+                'The data_path argument must be set to the path of a file or directory.'
+            )
+
+        data_files = self._get_file_list(data_path, limit)
+
+        files_processed = 0
+
+        for data_file in tqdm(data_files, desc='Training', disable=self.disable_progress):
+
+            previous_statement_text = None
+            previous_statement_search_text = ''
+
+            file_extension = data_file.split('.')[-1].lower()
+
+            statements_to_create = []
+
+            with open(data_file, 'r', encoding='utf-8') as file:
+
+                if self.file_extension == 'json':
+                    data = json.load(file)
+                    data = data['conversation']
+                elif file_extension == 'csv':
+                    use_header = bool(isinstance(next(iter(self.field_map.values())), str))
+
+                    if use_header:
+                        data = csv.DictReader(file)
+                    else:
+                        data = csv.reader(file)
+                elif file_extension == 'tsv':
+                    use_header = bool(isinstance(next(iter(self.field_map.values())), str))
+
+                    if use_header:
+                        data = csv.DictReader(file, delimiter='\t')
+                    else:
+                        data = csv.reader(file, delimiter='\t')
+                else:
+                    self.logger.warning(f'Skipping unsupported file type: {file_extension}')
+                    continue
+
+                files_processed += 1
+
+                text_row = self.field_map['text']
+
+                documents = self.chatbot.tagger.as_nlp_pipeline([
+                    (
+                        row[text_row],
+                        {
+                            # Include any defined metadata columns
+                            key: row[value]
+                            for key, value in self.field_map.items()
+                            if key != text_row
+                        }
+                    ) for row in data if len(row) > 0
+                ])
+
+                for document, context in documents:
+                    statement = Statement(
+                        text=document.text,
+                        conversation=context.get('conversation', 'training'),
+                        persona=context.get('persona', None),
+                        tags=context.get('tags', [])
+                    )
+
+                    if 'created_at' in context:
+                        statement.created_at = date_parser.parse(context['created_at'])
+
+                    statement.search_text = document._.search_index
+                    statement.search_in_response_to = previous_statement_search_text
+
+                    # Use the in_response_to attribute for the previous statement if
+                    # one is defined, otherwise use the last statement which was created
+                    if 'in_response_to' in self.field_map.keys():
+                        statement.in_response_to = context.get(self.field_map['in_response_to'], None)
+                    else:
+                        statement.in_response_to = previous_statement_text
+
+                    for preprocessor in self.chatbot.preprocessors:
+                        statement = preprocessor(statement)
+
+                    previous_statement_text = statement.text
+                    previous_statement_search_text = statement.search_text
+
+                    statements_to_create.append(statement)
+
+            self.chatbot.storage.create_many(statements_to_create)
+
+        if files_processed:
+            self.chatbot.logger.info(
+                'Training completed. {} files were read.'.format(files_processed)
+            )
+        else:
+            self.chatbot.logger.warning(
+                'No [{}] files were detected at: {}'.format(
+                    self.file_extension,
+                    data_path
+                )
+            )
+
+
+class CsvFileTrainer(GenericFileTrainer):
+    """
+    .. note::
+        Added in version 1.2.4
+
+    Allow chatbots to be trained with data from a CSV file or
+    directory of CSV files.
+
+    TSV files are also supported, as long as the file_extension
+    parameter is set to 'tsv'.
+
+    :param str file_extension: The file extension to look for when searching for files (defaults to 'csv').
+    :param str field_map: A dictionary containing the database column name to header mapping.
+        Values can be either the header name (str) or the column index (int).
+    """
+
+    def __init__(self, chatbot, **kwargs):
+        super().__init__(chatbot, **kwargs)
+
+        self.file_extension = kwargs.get('file_extension', 'csv')
+
+
+class JsonFileTrainer(GenericFileTrainer):
+    """
+    .. note::
+        Added in version 1.2.4
+
+    Allow chatbots to be trained with data from a JSON file or
+    directory of JSON files.
+
+    :param str field_map: A dictionary containing the database column name to header mapping.
+    """
+
+    def __init__(self, chatbot, **kwargs):
+        super().__init__(chatbot, **kwargs)
+
+        self.file_extension = 'json'
+
+        DEFAULT_STATEMENT_TO_KEY_MAPPING = {
+            'text': 'text',
+            'conversation': 'conversation',
+            'created_at': 'created_at',
+            'in_response_to': 'in_response_to',
+            'persona': 'persona',
+            'tags': 'tags'
+        }
+
+        self.field_map = kwargs.get(
+            'field_map',
+            DEFAULT_STATEMENT_TO_KEY_MAPPING
+        )
+
+
+class UbuntuCorpusTrainer(CsvFileTrainer):
+    """
+    .. note::
+        PENDING DEPRECATION: Please use the ``CsvFileTrainer`` for data formats similar to this one.
+
     Allow chatbots to be trained with the data from the Ubuntu Dialog Corpus.

     For more information about the Ubuntu Dialog Corpus visit:
     https://dataset.cs.mcgill.ca/ubuntu-corpus-1.0/
+
+    :param str ubuntu_corpus_data_directory: The directory where the Ubuntu corpus data is already located, or where it should be downloaded and extracted.
     """

     def __init__(self, chatbot, **kwargs):
         super().__init__(chatbot, **kwargs)
         home_directory = os.path.expanduser('~')

-        self.data_download_url =
-            'ubuntu_corpus_data_download_url',
-            'http://cs.mcgill.ca/~jpineau/datasets/ubuntu-corpus-1.0/ubuntu_dialogs.tgz'
-        )
+        self.data_download_url = None

         self.data_directory = kwargs.get(
             'ubuntu_corpus_data_directory',
             os.path.join(home_directory, 'ubuntu_data')
         )

-
+        # Directory containing extracted data
+        self.data_path = os.path.join(
             self.data_directory, 'ubuntu_dialogs'
         )

+        self.field_map = {
+            'text': 3,
+            'created_at': 0,
+            'persona': 1,
+        }
+
     def is_downloaded(self, file_path):
         """
         Check if the data file is already downloaded.
```
```diff
@@ -222,7 +449,6 @@ class UbuntuCorpusTrainer(Trainer):
         """
         Download a file from the given url.
         Show a progress indicator for the download status.
-        Based on: http://stackoverflow.com/a/15645088/1547223
         """
         import requests

```
```diff
@@ -238,7 +464,8 @@ class UbuntuCorpusTrainer(Trainer):
             return file_path

         with open(file_path, 'wb') as open_file:
-
+            if show_status:
+                print('Downloading %s' % url)
             response = requests.get(url, stream=True)
             total_length = response.headers.get('content-length')

```
```diff
@@ -246,136 +473,97 @@ class UbuntuCorpusTrainer(Trainer):
                 # No content length header
                 open_file.write(response.content)
             else:
-
-
-
-
+                for data in tqdm(
+                    response.iter_content(chunk_size=4096),
+                    desc='Downloading',
+                    disable=not show_status
+                ):
                     open_file.write(data)
-                    if show_status:
-                        done = int(50 * download / total_length)
-                        sys.stdout.write('\r[%s%s]' % ('=' * done, ' ' * (50 - done)))
-                        sys.stdout.flush()

-
-
-
-        print('Download location: %s' % file_path)
+        if show_status:
+            print('Download location: %s' % file_path)
         return file_path

     def extract(self, file_path):
         """
         Extract a tar file at the specified file path.
         """
-
-
-        if not os.path.exists(self.extracted_data_directory):
-            os.makedirs(self.extracted_data_directory)
+        if not self.disable_progress:
+            print('Extracting {}'.format(file_path))

-
-
-        for member in members:
-            # This will be the current file being extracted
-            yield member
+        if not os.path.exists(self.data_path):
+            os.makedirs(self.data_path)

-
-        def is_within_directory(directory, target):
+        def is_within_directory(directory, target):

-
-
+            abs_directory = os.path.abspath(directory)
+            abs_target = os.path.abspath(target)

-
+            prefix = os.path.commonprefix([abs_directory, abs_target])

-
+            return prefix == abs_directory

-
+        def safe_extract(tar, path='.', members=None, *, numeric_owner=False):

-
-
-
-
+            for member in tar.getmembers():
+                member_path = os.path.join(path, member.name)
+                if not is_within_directory(path, member_path):
+                    raise Exception('Attempted Path Traversal in Tar File')

-
+            tar.extractall(path, members, numeric_owner=numeric_owner)

-
+        try:
+            with tarfile.open(file_path, 'r') as tar:
+                safe_extract(tar, path=self.data_path, members=tqdm(tar, disable=self.disable_progress))
+        except tarfile.ReadError as e:
+            raise self.TrainerInitializationException(
+                f'The provided data file is not a valid tar file: {file_path}'
+            ) from e

-        self.chatbot.logger.info('File extracted to {}'.format(self.
+        self.chatbot.logger.info('File extracted to {}'.format(self.data_path))

         return True
-
-    def
+
+    def _get_file_list(self, data_path, limit):
         """
-
+        Get a list of files to read from the data set.
         """
+
+        if self.data_download_url is None:
+            raise self.TrainerInitializationException(
+                'The data_download_url attribute must be set before calling train().'
+            )
+
         # Download and extract the Ubuntu dialog corpus if needed
         corpus_download_path = self.download(self.data_download_url)

         # Extract if the directory does not already exist
-        if not self.is_extracted(
+        if not self.is_extracted(data_path):
             self.extract(corpus_download_path)

         extracted_corpus_path = os.path.join(
-
-            '**', '**', '*.tsv'
+            data_path, '**', '**', '*.tsv'
         )

-
-
-
-            yield items[start_index:end_index]
+        # Use iglob instead of glob for better performance with
+        # large directories because it returns an iterator
+        data_files = glob.iglob(extracted_corpus_path)

-
+        for index, file_path in enumerate(data_files):
+            if limit is not None and index >= limit:
+                break

-
-        if limit is not None:
-            file_list = file_list[:limit]
+            yield file_path

-
+    def train(self, data_download_url, limit=None):
+        """
+        :param str data_download_url: The URL to download the Ubuntu dialog corpus from.
+        :param int limit: The maximum number of files to train from.
+        """
+        self.data_download_url = data_download_url

         start_time = time.time()
+        super().train(self.data_path, limit=limit)

-
-
-        statements_from_file = []
-
-        for tsv_file in tqdm(tsv_files, desc=f'Training with batch {batch_number} of {len(file_groups)}'):
-            with open(tsv_file, 'r', encoding='utf-8') as tsv:
-                reader = csv.reader(tsv, delimiter='\t')
-
-                previous_statement_text = None
-                previous_statement_search_text = ''
-
-                documents = self.chatbot.tagger.as_nlp_pipeline([
-                    (
-                        row[3],
-                        {
-                            'persona': row[1],
-                            'created_at': row[0],
-                        }
-                    ) for row in reader if len(row) > 0
-                ])
-
-                for document, context in documents:
-
-                    statement_search_text = document._.search_index
-
-                    statement = Statement(
-                        text=document.text,
-                        in_response_to=previous_statement_text,
-                        conversation='training',
-                        created_at=date_parser.parse(context['created_at']),
-                        persona=context['persona'],
-                        search_text=statement_search_text,
-                        search_in_response_to=previous_statement_search_text
-                    )
-
-                    for preprocessor in self.chatbot.preprocessors:
-                        statement = preprocessor(statement)
-
-                    previous_statement_text = statement.text
-                    previous_statement_search_text = statement_search_text
-
-                    statements_from_file.append(statement)
-
-            self.chatbot.storage.create_many(statements_from_file)
-
-        print('Training took', time.time() - start_time, 'seconds.')
+        if not self.disable_progress:
+            print('Training took', time.time() - start_time, 'seconds.')
```
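Putting the new trainer classes together, usage presumably looks something like the following sketch; the bot name and file paths are placeholders, and the keyword arguments mirror the docstrings in the diff above:

```python
from chatterbot import ChatBot
from chatterbot.trainers import CsvFileTrainer, JsonFileTrainer

chatbot = ChatBot('Example Bot')  # placeholder bot

# CSV training: field_map values may be header names (str) or
# column indexes (int), per the GenericFileTrainer docstring.
csv_trainer = CsvFileTrainer(chatbot, field_map={
    'text': 'text',
    'conversation': 'conversation',
})
csv_trainer.train('./data/conversations.csv')  # placeholder path

# JSON training: each file is expected to hold a top-level
# 'conversation' key, as read by GenericFileTrainer.train().
json_trainer = JsonFileTrainer(chatbot)
json_trainer.train('./data/', limit=10)  # directory input, capped at 10 files
```

Note that `UbuntuCorpusTrainer` now routes through the same code path: `train()` takes the download URL directly, rather than reading it from the old `ubuntu_corpus_data_download_url` constructor keyword.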
chatterbot/utils.py
CHANGED

```diff
@@ -89,30 +89,21 @@ def get_response_time(chatbot, statement='Hello'):
     return time.time() - start_time


-def
+def get_model_for_language(language):
     """
-
-    :param description: Training description
-    :type description: str
-
-    :param iteration_counter: Incremental counter
-    :type iteration_counter: int
-
-    :param total_items: total number items
-    :type total_items: int
-
-    :param progress_bar_length: Progress bar length
-    :type progress_bar_length: int
-
-    :returns: void
-    :rtype: void
-
-    DEPRECTTED: use `tqdm` instead
+    Returns the spacy model for the specified language.
     """
-
-
-
-
-
-
-
+    from chatterbot import constants
+
+    try:
+        model = constants.DEFAULT_LANGUAGE_TO_SPACY_MODEL_MAP[language]
+    except KeyError as e:
+        if hasattr(language, 'ENGLISH_NAME'):
+            language_name = language.ENGLISH_NAME
+        else:
+            language_name = language
+        raise KeyError(
+            f'A corresponding spacy model for "{language_name}" could not be found.'
+        ) from e
+
+    return model
```
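The new helper centralizes the model lookup that tagging.py previously performed inline; a brief usage sketch, with a deliberately invalid value to show the error path:

```python
from chatterbot import languages
from chatterbot.utils import get_model_for_language

# Resolves a language constant to its spaCy model name via
# constants.DEFAULT_LANGUAGE_TO_SPACY_MODEL_MAP.
model = get_model_for_language(languages.ENG)

# Unmapped languages raise a KeyError that names the language
# (using ENGLISH_NAME when the constant provides it).
try:
    get_model_for_language('xx-unmapped')  # placeholder, not a real constant
except KeyError as error:
    print(error)
```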
chatterbot/vectorstores.py
ADDED

```diff
@@ -0,0 +1,74 @@
+"""
+Redis vector store.
+"""
+from __future__ import annotations
+
+from typing import Any, List, Sequence
+
+from langchain_core.documents import Document
+from redisvl.redis.utils import convert_bytes
+from redisvl.query import FilterQuery
+
+from langchain_core.documents import Document
+from langchain_redis.vectorstores import RedisVectorStore as LangChainRedisVectorStore
+
+
+class RedisVectorStore(LangChainRedisVectorStore):
+    """
+    Redis vector store integration.
+    """
+
+    def query_search(
+        self,
+        k=4,
+        filter=None,
+        sort_by=None,
+    ) -> List[Document]:
+        """
+        Return docs based on the provided query.
+
+        k: int, default=4
+            Number of documents to return.
+        filter: str, default=None
+            A filter expression to apply to the query.
+        sort_by: str, default=None
+            A field to sort the results by.
+
+        returns:
+            A list of Documents most matching the query.
+        """
+        from chatterbot import ChatBot
+
+        return_fields = [
+            self.config.content_field
+        ]
+        return_fields += [
+            field.name
+            for field in self._index.schema.fields.values()
+            if field.name
+            not in [self.config.embedding_field, self.config.content_field]
+        ]
+
+        query = FilterQuery(
+            return_fields=return_fields,
+            num_results=k,
+            filter_expression=filter,
+            sort_by=sort_by,
+        )
+
+        try:
+            results = self._index.query(query)
+        except Exception as e:
+            raise ChatBot.ChatBotException(f'Error querying index: {query}') from e
+
+        if results:
+            with self._index.client.pipeline(transaction=False) as pipe:
+                for document in results:
+                    pipe.hgetall(document['id'])
+                full_documents = convert_bytes(pipe.execute())
+        else:
+            full_documents = []
+
+        return self._prepare_docs_full(
+            True, results, full_documents, True
+        )
```