ChatterBot 1.2.1__py3-none-any.whl → 1.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatterbot/__init__.py +1 -1
- chatterbot/chatterbot.py +41 -8
- chatterbot/comparisons.py +32 -15
- chatterbot/logic/best_match.py +42 -35
- chatterbot/logic/specific_response.py +52 -9
- chatterbot/logic/unit_conversion.py +4 -3
- chatterbot/response_selection.py +1 -1
- chatterbot/search.py +65 -17
- chatterbot/storage/__init__.py +2 -0
- chatterbot/storage/django_storage.py +13 -23
- chatterbot/storage/mongodb.py +7 -26
- chatterbot/storage/redis.py +390 -0
- chatterbot/storage/sql_storage.py +77 -68
- chatterbot/storage/storage_adapter.py +9 -7
- chatterbot/trainers.py +3 -3
- chatterbot/vectorstores.py +74 -0
- {ChatterBot-1.2.1.dist-info → chatterbot-1.2.3.dist-info}/METADATA +9 -3
- {ChatterBot-1.2.1.dist-info → chatterbot-1.2.3.dist-info}/RECORD +21 -19
- {ChatterBot-1.2.1.dist-info → chatterbot-1.2.3.dist-info}/WHEEL +1 -1
- {ChatterBot-1.2.1.dist-info → chatterbot-1.2.3.dist-info}/LICENSE +0 -0
- {ChatterBot-1.2.1.dist-info → chatterbot-1.2.3.dist-info}/top_level.txt +0 -0
chatterbot/storage/mongodb.py
CHANGED
@@ -82,6 +82,7 @@ class MongoDatabaseAdapter(StorageAdapter):
|
|
82
82
|
exclude_text_words = kwargs.pop('exclude_text_words', [])
|
83
83
|
persona_not_startswith = kwargs.pop('persona_not_startswith', None)
|
84
84
|
search_text_contains = kwargs.pop('search_text_contains', None)
|
85
|
+
search_in_response_to_contains = kwargs.pop('search_in_response_to_contains', None)
|
85
86
|
|
86
87
|
if tags:
|
87
88
|
kwargs['tags'] = {
|
@@ -127,6 +128,12 @@ class MongoDatabaseAdapter(StorageAdapter):
|
|
127
128
|
])
|
128
129
|
kwargs['search_text'] = re.compile(or_regex)
|
129
130
|
|
131
|
+
if search_in_response_to_contains:
|
132
|
+
or_regex = '|'.join([
|
133
|
+
'{}'.format(re.escape(word)) for word in search_in_response_to_contains.split(' ')
|
134
|
+
])
|
135
|
+
kwargs['search_in_response_to'] = re.compile(or_regex)
|
136
|
+
|
130
137
|
mongo_ordering = []
|
131
138
|
|
132
139
|
if order_by:
|
@@ -159,13 +166,6 @@ class MongoDatabaseAdapter(StorageAdapter):
|
|
159
166
|
if 'tags' in kwargs:
|
160
167
|
kwargs['tags'] = list(set(kwargs['tags']))
|
161
168
|
|
162
|
-
if 'search_text' not in kwargs:
|
163
|
-
kwargs['search_text'] = self.tagger.get_text_index_string(kwargs['text'])
|
164
|
-
|
165
|
-
if 'search_in_response_to' not in kwargs:
|
166
|
-
if kwargs.get('in_response_to'):
|
167
|
-
kwargs['search_in_response_to'] = self.tagger.get_text_index_string(kwargs['in_response_to'])
|
168
|
-
|
169
169
|
inserted = self.statements.insert_one(kwargs)
|
170
170
|
|
171
171
|
kwargs['id'] = inserted.inserted_id
|
@@ -178,20 +178,6 @@ class MongoDatabaseAdapter(StorageAdapter):
|
|
178
178
|
"""
|
179
179
|
create_statements = []
|
180
180
|
|
181
|
-
# Check if any statements already have a search text
|
182
|
-
have_search_text = any(statement.search_text for statement in statements)
|
183
|
-
|
184
|
-
# Generate search text values in bulk
|
185
|
-
if not have_search_text:
|
186
|
-
search_text_documents = self.tagger.as_nlp_pipeline([statement.text for statement in statements])
|
187
|
-
response_search_text_documents = self.tagger.as_nlp_pipeline([statement.in_response_to or '' for statement in statements])
|
188
|
-
|
189
|
-
for statement, search_text_document, response_search_text_document in zip(
|
190
|
-
statements, search_text_documents, response_search_text_documents
|
191
|
-
):
|
192
|
-
statement.search_text = search_text_document._.search_index
|
193
|
-
statement.search_in_response_to = response_search_text_document._.search_index
|
194
|
-
|
195
181
|
for statement in statements:
|
196
182
|
statement_data = statement.serialize()
|
197
183
|
tag_data = list(set(statement_data.pop('tags', [])))
|
@@ -206,11 +192,6 @@ class MongoDatabaseAdapter(StorageAdapter):
|
|
206
192
|
data.pop('id', None)
|
207
193
|
data.pop('tags', None)
|
208
194
|
|
209
|
-
data['search_text'] = self.tagger.get_text_index_string(data['text'])
|
210
|
-
|
211
|
-
if data.get('in_response_to'):
|
212
|
-
data['search_in_response_to'] = self.tagger.get_text_index_string(data['in_response_to'])
|
213
|
-
|
214
195
|
update_data = {
|
215
196
|
'$set': data
|
216
197
|
}
|
@@ -0,0 +1,390 @@
|
|
1
|
+
from datetime import datetime
|
2
|
+
from chatterbot.storage import StorageAdapter
|
3
|
+
from chatterbot.conversation import Statement as StatementObject
|
4
|
+
|
5
|
+
|
6
|
+
# Characters with special meaning in the redis query syntax. Each one is
# mapped to a backslash-escaped form so user text can be interpolated into
# a query literally.
# TODO: This list may not be exhaustive.
# Is there a full list of characters reserved by redis?
REDIS_ESCAPE_CHARACTERS = {
    character: '\\' + character
    for character in ('\\', ':', '|', '%', '!', '-')
}

# Translation table built once so escaping is a single str.translate pass.
REDIS_TRANSLATION_TABLE = str.maketrans(REDIS_ESCAPE_CHARACTERS)


def _escape_redis_special_characters(text):
    """
    Return ``text`` with every character listed in
    ``REDIS_ESCAPE_CHARACTERS`` prefixed by a backslash, so that it is
    treated literally inside a redis query string.
    """
    escaped_text = text.translate(REDIS_TRANSLATION_TABLE)
    return escaped_text
|
24
|
+
|
25
|
+
|
26
|
+
class RedisVectorStorageAdapter(StorageAdapter):
    """
    .. warning:: BETA feature (Released March, 2025): this storage adapter is new
        and experimental. Its functionality and default parameters might change
        in the future and its behavior has not yet been finalized.

    The RedisVectorStorageAdapter allows ChatterBot to store conversation
    data in a redis instance.

    All parameters are optional, by default a redis instance on localhost is assumed.

    :keyword database_uri: eg: redis://localhost:6379/0
        The database_uri can be specified to choose a redis instance.
    :type database_uri: str
    """

    class RedisMetaDataType:
        """
        Subclass for redis config metadata type enumerator.
        """
        TAG = 'tag'
        TEXT = 'text'
        NUMERIC = 'numeric'

    def __init__(self, **kwargs):
        """
        Connect to redis and create the vector store that holds statements.

        Each statement is stored as a document whose content is its
        ``in_response_to`` text (the embedded value). The remaining
        statement fields are stored as indexed metadata.
        """
        super().__init__(**kwargs)
        from chatterbot.vectorstores import RedisVectorStore
        from langchain_redis import RedisConfig
        from langchain_huggingface import HuggingFaceEmbeddings

        self.database_uri = kwargs.get('database_uri', 'redis://localhost:6379/0')

        config = RedisConfig(
            index_name='chatterbot',
            redis_url=self.database_uri,
            content_field='in_response_to',
            metadata_schema=[
                {
                    'name': 'conversation',
                    'type': self.RedisMetaDataType.TAG,
                },
                {
                    'name': 'text',
                    'type': self.RedisMetaDataType.TEXT,
                },
                {
                    'name': 'created_at',
                    'type': self.RedisMetaDataType.NUMERIC,
                },
                {
                    'name': 'persona',
                    'type': self.RedisMetaDataType.TEXT,
                },
                {
                    'name': 'tags',
                    'type': self.RedisMetaDataType.TAG,
                    # NOTE(review): tags are stored '|'-joined (see create),
                    # but no separator is configured here. The redis docs
                    # give ',' as the TAG default. Confirm that filtering on a
                    # single tag matches multi-tag statements.
                    # 'separator': '|'
                },
            ],
        )

        # TODO should this call from_existing_index if connecting to
        # a redis instance that already contains data?

        self.logger.info('Loading HuggingFace embeddings')

        # TODO: Research different embeddings
        # https://python.langchain.com/docs/integrations/vectorstores/mongodb_atlas/#initialization

        embeddings = HuggingFaceEmbeddings(
            model_name='sentence-transformers/all-mpnet-base-v2'
        )

        self.logger.info('Creating Redis Vector Store')

        self.vector_store = RedisVectorStore(embeddings, config=config)

    def _get_key_prefix(self):
        """
        Return the prefix shared by the redis keys of this adapter's
        documents (``'<index_name>:'``).
        """
        return f'{self.vector_store.config.index_name}:'

    def _get_document_id(self, key):
        """
        Convert a redis key such as ``'<index_name>:<id>'`` into the bare
        document id the vector store expects.

        The key may be ``str`` or ``bytes``. A value without the index
        prefix is returned unchanged (as ``str``).
        """
        if isinstance(key, bytes):
            key = key.decode()
        key = str(key)

        prefix = self._get_key_prefix()
        if key.startswith(prefix):
            return key[len(prefix):]
        return key

    def _statement_to_metadata(self, statement):
        """
        Build the metadata dictionary stored alongside a statement document.

        Missing values become empty strings because the vector store does
        not use null values.
        """
        return {
            'text': statement.text,
            'conversation': statement.conversation or '',
            # NOTE: created_at is stored as a yymmdd integer (the index
            # declares it NUMERIC), so the time of day is not preserved
            'created_at': int(statement.created_at.strftime('%y%m%d')),
            'persona': statement.persona or '',
            'tags': '|'.join(statement.tags) if statement.tags else '',
        }

    def get_statement_model(self):
        """
        Return the statement model.
        """
        from langchain_core.documents import Document

        return Document

    def model_to_object(self, document):
        """
        Convert a langchain ``Document`` read from the vector store into a
        ChatterBot ``Statement``.
        """
        # The vector store does not use null values, so an empty content
        # string is converted back to None to match the expected type
        values = {
            'in_response_to': document.page_content or None,
        }

        if document.id:
            values['id'] = document.id

        values.update(document.metadata)

        # Tags are stored as a single '|'-separated string
        tags = values.get('tags')
        values['tags'] = list(set(tags.split('|'))) if tags else []

        return StatementObject(**values)

    def count(self):
        """
        Return the number of entries in the database.

        Only keys that belong to this adapter's index are counted, so other
        data in the same redis database does not inflate the result.
        """
        client = self.vector_store.index.client

        return sum(1 for _ in client.scan_iter(f'{self._get_key_prefix()}*'))

    def remove(self, statement):
        """
        Delete the document that stores the given statement.

        ``statement.id`` may be the full redis key (``'<index_name>:<id>'``)
        or the bare document id.
        """
        self.vector_store.delete(ids=[self._get_document_id(statement.id)])

    def filter(self, page_size=4, **kwargs):
        """
        Returns a list of objects from the database.
        The kwargs parameter can contain any number
        of attributes. Only objects which contain all
        listed attributes and in which all values match
        for all listed attributes will be returned.

        :param page_size: The maximum number of statements to return.

        kwargs:
            - conversation
            - persona
            - tags
            - in_response_to
            - text
            - exclude_text
            - exclude_text_words
            - persona_not_startswith
            - search_in_response_to_contains
            - order_by

        Empty or None values for ``text``, ``exclude_text``,
        ``exclude_text_words``, ``persona_not_startswith`` and
        ``search_in_response_to_contains`` are ignored.
        """
        from redisvl.query.filter import Tag, Text

        # https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/query_syntax/
        conditions = []

        if 'in_response_to' in kwargs:
            conditions.append(Text('in_response_to') == kwargs['in_response_to'])

        if 'conversation' in kwargs:
            conditions.append(Tag('conversation') == kwargs['conversation'])

        if 'persona' in kwargs:
            # persona is declared as a TEXT field in the index schema, so it
            # must be queried with a Text filter (not a Tag filter)
            conditions.append(Text('persona') == kwargs['persona'])

        if 'tags' in kwargs:
            conditions.append(Tag('tags') == kwargs['tags'])

        exclude_text = kwargs.get('exclude_text')
        if exclude_text:
            # NOTE(review): Text(...) != renders a quoted phrase; confirm the
            # '%%...%%' and '|' syntax is honored inside it.
            conditions.append(Text('text') != '|'.join([
                f'%%{text}%%' for text in exclude_text
            ]))

        exclude_text_words = kwargs.get('exclude_text_words')
        if exclude_text_words:
            # Words are interpolated into a raw query, so escape them the
            # same way the `text` filter below does
            excluded_words = '|'.join([
                f'%%{_escape_redis_special_characters(word)}%%'
                for word in exclude_text_words
            ])
            conditions.append(Text('text') % f'-({excluded_words})')

        persona_not_startswith = kwargs.get('persona_not_startswith')
        if persona_not_startswith:
            persona_prefix = _escape_redis_special_characters(persona_not_startswith)
            conditions.append(Text('persona') % f'-(%%{persona_prefix}%%)')

        text_words = _escape_redis_special_characters(kwargs.get('text') or '').split()
        if text_words:
            conditions.append(
                Text('text') % '|'.join([f'%%{word}%%' for word in text_words])
            )

        # AND all conditions together, left to right
        filter_condition = None
        for condition in conditions:
            if filter_condition is None:
                filter_condition = condition
            else:
                filter_condition = filter_condition & condition

        order_by = kwargs.get('order_by')
        ordering = ','.join(order_by) if order_by else None

        search_text = kwargs.get('search_in_response_to_contains')

        if search_text is not None:
            # TODO similarity_search_with_score
            documents = self.vector_store.similarity_search(
                search_text,
                k=page_size,  # The number of results to return
                return_all=True,  # Include the full document with IDs
                filter=filter_condition,
                sort_by=ordering
            )
        else:
            documents = self.vector_store.query_search(
                k=page_size,
                filter=filter_condition,
                sort_by=ordering
            )

        return [self.model_to_object(document) for document in documents]

    def create(
        self,
        text,
        in_response_to=None,
        tags=None,
        **kwargs
    ):
        """
        Creates a new statement matching the keyword arguments specified.
        Returns the created statement.

        ``in_response_to`` becomes the embedded document content; the other
        values are stored as metadata. A ``created_at`` datetime, when given,
        is kept on the returned statement; otherwise the current time is used.
        """
        now = datetime.now()
        created_at = kwargs.get('created_at') or now

        if isinstance(created_at, datetime):
            # NOTE: `created_at` must have a valid numeric value or results will
            # not be returned for similarity_search for some reason
            stored_created_at = int(created_at.strftime('%y%m%d'))
        else:
            # A non-datetime value is stored as provided
            stored_created_at = created_at
            created_at = now

        # `or ''`: the vector store does not use null values
        metadata = {
            'text': text,
            'category': kwargs.get('category') or '',
            'created_at': stored_created_at,
            'tags': '|'.join(tags) if tags else '',
            'conversation': kwargs.get('conversation') or '',
            'persona': kwargs.get('persona') or '',
        }

        ids = self.vector_store.add_texts([in_response_to or ''], [metadata])

        return StatementObject(
            id=ids[0],
            text=text,
            in_response_to=in_response_to,
            category=metadata['category'],
            created_at=created_at,
            tags=tags or [],
            conversation=metadata['conversation'],
            persona=metadata['persona'],
        )

    def create_many(self, statements):
        """
        Creates multiple statement entries.
        """
        Document = self.get_statement_model()

        documents = [
            Document(
                page_content=statement.in_response_to or '',
                metadata=self._statement_to_metadata(statement),
            ) for statement in statements
        ]

        self.logger.info('Adding documents to the vector store')

        self.vector_store.add_documents(documents)

    def update(self, statement):
        """
        Modifies an entry in the database.
        Creates an entry if one does not exist.

        A statement with an ``id`` is written back under that key; a
        statement without one is added as a new document.
        """
        metadata = self._statement_to_metadata(statement)
        page_content = statement.in_response_to or ''

        if statement.id:
            self.vector_store.add_texts(
                [page_content], [metadata], keys=[self._get_document_id(statement.id)]
            )
        else:
            Document = self.get_statement_model()
            self.vector_store.add_documents([
                Document(page_content=page_content, metadata=metadata)
            ])

    def get_random(self):
        """
        Returns a random statement from the database.

        :raises EmptyDatabaseException: if no statement could be found.
        """
        import random

        client = self.vector_store.index.client
        prefix = self._get_key_prefix()

        random_key = client.randomkey()

        if random_key is not None:
            decoded_key = random_key.decode() if isinstance(random_key, bytes) else str(random_key)

            if not decoded_key.startswith(prefix):
                # randomkey() samples the whole redis database, which may
                # contain keys that do not belong to this index
                index_keys = list(client.scan_iter(f'{prefix}*'))
                random_key = random.choice(index_keys) if index_keys else None

        if random_key is not None:
            documents = self.vector_store.get_by_ids([self._get_document_id(random_key)])

            if documents:
                return self.model_to_object(documents[0])

        raise self.EmptyDatabaseException()

    def drop(self):
        """
        Remove all existing documents from the database.
        """
        client = self.vector_store.index.client

        for key in client.scan_iter(f'{self._get_key_prefix()}*'):
            client.delete(key)

        # The search index itself is intentionally kept: there is no step
        # to recreate the index after it is dropped (really what we want is
        # to delete all the keys in the index, but keep the index itself),
        # so self.vector_store.index.delete(drop=True) is not called.
|
@@ -114,8 +114,8 @@ class SQLStorageAdapter(StorageAdapter):
|
|
114
114
|
record = query.first()
|
115
115
|
|
116
116
|
session.delete(record)
|
117
|
-
|
118
|
-
|
117
|
+
session.commit()
|
118
|
+
session.close()
|
119
119
|
|
120
120
|
def filter(self, **kwargs):
|
121
121
|
"""
|
@@ -139,6 +139,7 @@ class SQLStorageAdapter(StorageAdapter):
|
|
139
139
|
exclude_text_words = kwargs.pop('exclude_text_words', [])
|
140
140
|
persona_not_startswith = kwargs.pop('persona_not_startswith', None)
|
141
141
|
search_text_contains = kwargs.pop('search_text_contains', None)
|
142
|
+
search_in_response_to_contains = kwargs.pop('search_in_response_to_contains', None)
|
142
143
|
|
143
144
|
# Convert a single sting into a list if only one tag is provided
|
144
145
|
if type(tags) == str:
|
@@ -180,6 +181,14 @@ class SQLStorageAdapter(StorageAdapter):
|
|
180
181
|
or_(*or_query)
|
181
182
|
)
|
182
183
|
|
184
|
+
if search_in_response_to_contains:
|
185
|
+
or_query = [
|
186
|
+
Statement.search_in_response_to.contains(word) for word in search_in_response_to_contains.split(' ')
|
187
|
+
]
|
188
|
+
statements = statements.filter(
|
189
|
+
or_(*or_query)
|
190
|
+
)
|
191
|
+
|
183
192
|
if order_by:
|
184
193
|
|
185
194
|
if 'created_at' in order_by:
|
@@ -196,7 +205,15 @@ class SQLStorageAdapter(StorageAdapter):
|
|
196
205
|
|
197
206
|
session.close()
|
198
207
|
|
199
|
-
def create(
|
208
|
+
def create(
|
209
|
+
self,
|
210
|
+
text,
|
211
|
+
in_response_to=None,
|
212
|
+
tags=None,
|
213
|
+
search_text=None,
|
214
|
+
search_in_response_to=None,
|
215
|
+
**kwargs
|
216
|
+
):
|
200
217
|
"""
|
201
218
|
Creates a new statement matching the keyword arguments specified.
|
202
219
|
Returns the created statement.
|
@@ -206,19 +223,25 @@ class SQLStorageAdapter(StorageAdapter):
|
|
206
223
|
|
207
224
|
session = self.Session()
|
208
225
|
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
226
|
+
if search_text is None:
|
227
|
+
if self.raise_on_missing_search_text:
|
228
|
+
raise Exception('generate a search_text value')
|
229
|
+
|
230
|
+
if search_in_response_to is None and in_response_to is not None:
|
231
|
+
if self.raise_on_missing_search_text:
|
232
|
+
raise Exception('generate a search_in_response_to value')
|
233
|
+
|
234
|
+
statement = Statement(
|
235
|
+
text=text,
|
236
|
+
in_response_to=in_response_to,
|
237
|
+
search_text=search_text,
|
238
|
+
search_in_response_to=search_in_response_to,
|
239
|
+
**kwargs
|
240
|
+
)
|
241
|
+
|
242
|
+
tags = frozenset(tags) if tags else frozenset()
|
243
|
+
for tag_name in frozenset(tags):
|
244
|
+
# TODO: Query existing tags in bulk
|
222
245
|
tag = session.query(Tag).filter_by(name=tag_name).first()
|
223
246
|
|
224
247
|
if not tag:
|
@@ -235,7 +258,7 @@ class SQLStorageAdapter(StorageAdapter):
|
|
235
258
|
|
236
259
|
statement_object = self.model_to_object(statement)
|
237
260
|
|
238
|
-
|
261
|
+
session.close()
|
239
262
|
|
240
263
|
return statement_object
|
241
264
|
|
@@ -256,14 +279,8 @@ class SQLStorageAdapter(StorageAdapter):
|
|
256
279
|
|
257
280
|
# Generate search text values in bulk
|
258
281
|
if not have_search_text:
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
for statement, search_text_document, response_search_text_document in zip(
|
263
|
-
statements, search_text_documents, response_search_text_documents
|
264
|
-
):
|
265
|
-
statement.search_text = search_text_document._.search_index
|
266
|
-
statement.search_in_response_to = response_search_text_document._.search_index
|
282
|
+
if self.raise_on_missing_search_text:
|
283
|
+
raise Exception('generate bulk_search_text values')
|
267
284
|
|
268
285
|
for statement in statements:
|
269
286
|
|
@@ -305,48 +322,50 @@ class SQLStorageAdapter(StorageAdapter):
|
|
305
322
|
Statement = self.get_model('statement')
|
306
323
|
Tag = self.get_model('tag')
|
307
324
|
|
308
|
-
|
309
|
-
|
310
|
-
record = None
|
311
|
-
|
312
|
-
if hasattr(statement, 'id') and statement.id is not None:
|
313
|
-
record = session.query(Statement).get(statement.id)
|
314
|
-
else:
|
315
|
-
record = session.query(Statement).filter(
|
316
|
-
Statement.text == statement.text,
|
317
|
-
Statement.conversation == statement.conversation,
|
318
|
-
).first()
|
319
|
-
|
320
|
-
# Create a new statement entry if one does not already exist
|
321
|
-
if not record:
|
322
|
-
record = Statement(
|
323
|
-
text=statement.text,
|
324
|
-
conversation=statement.conversation,
|
325
|
-
persona=statement.persona
|
326
|
-
)
|
325
|
+
session = self.Session()
|
326
|
+
record = None
|
327
327
|
|
328
|
-
|
329
|
-
record
|
328
|
+
if hasattr(statement, 'id') and statement.id is not None:
|
329
|
+
record = session.query(Statement).get(statement.id)
|
330
|
+
else:
|
331
|
+
record = session.query(Statement).filter(
|
332
|
+
Statement.text == statement.text,
|
333
|
+
Statement.conversation == statement.conversation,
|
334
|
+
).first()
|
335
|
+
|
336
|
+
# Create a new statement entry if one does not already exist
|
337
|
+
if not record:
|
338
|
+
record = Statement(
|
339
|
+
text=statement.text,
|
340
|
+
conversation=statement.conversation,
|
341
|
+
persona=statement.persona
|
342
|
+
)
|
330
343
|
|
331
|
-
|
344
|
+
# Update the response value
|
345
|
+
record.in_response_to = statement.in_response_to
|
332
346
|
|
333
|
-
|
347
|
+
record.created_at = statement.created_at
|
334
348
|
|
335
|
-
|
336
|
-
|
349
|
+
if not statement.search_text:
|
350
|
+
if self.raise_on_missing_search_text:
|
351
|
+
raise Exception('update issued without search_text value')
|
337
352
|
|
338
|
-
|
339
|
-
|
353
|
+
if statement.in_response_to and not statement.search_in_response_to:
|
354
|
+
if self.raise_on_missing_search_text:
|
355
|
+
raise Exception('update issued without search_in_response_to value')
|
340
356
|
|
341
|
-
|
342
|
-
|
343
|
-
tag = Tag(name=tag_name)
|
357
|
+
for tag_name in statement.get_tags():
|
358
|
+
tag = session.query(Tag).filter_by(name=tag_name).first()
|
344
359
|
|
345
|
-
|
360
|
+
if not tag:
|
361
|
+
# Create the record
|
362
|
+
tag = Tag(name=tag_name)
|
346
363
|
|
347
|
-
|
364
|
+
record.tags.append(tag)
|
348
365
|
|
349
|
-
|
366
|
+
session.add(record)
|
367
|
+
session.commit()
|
368
|
+
session.close()
|
350
369
|
|
351
370
|
def get_random(self):
|
352
371
|
"""
|
@@ -388,13 +407,3 @@ class SQLStorageAdapter(StorageAdapter):
|
|
388
407
|
"""
|
389
408
|
from chatterbot.ext.sqlalchemy_app.models import Base
|
390
409
|
Base.metadata.create_all(self.engine)
|
391
|
-
|
392
|
-
def _session_finish(self, session, statement_text=None):
|
393
|
-
from sqlalchemy.exc import InvalidRequestError
|
394
|
-
try:
|
395
|
-
session.commit()
|
396
|
-
except InvalidRequestError:
|
397
|
-
# Log the statement text and the exception
|
398
|
-
self.logger.exception(statement_text)
|
399
|
-
finally:
|
400
|
-
session.close()
|