ChatterBot 1.2.1__py3-none-any.whl → 1.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -82,6 +82,7 @@ class MongoDatabaseAdapter(StorageAdapter):
82
82
  exclude_text_words = kwargs.pop('exclude_text_words', [])
83
83
  persona_not_startswith = kwargs.pop('persona_not_startswith', None)
84
84
  search_text_contains = kwargs.pop('search_text_contains', None)
85
+ search_in_response_to_contains = kwargs.pop('search_in_response_to_contains', None)
85
86
 
86
87
  if tags:
87
88
  kwargs['tags'] = {
@@ -127,6 +128,12 @@ class MongoDatabaseAdapter(StorageAdapter):
127
128
  ])
128
129
  kwargs['search_text'] = re.compile(or_regex)
129
130
 
131
+ if search_in_response_to_contains:
132
+ or_regex = '|'.join([
133
+ '{}'.format(re.escape(word)) for word in search_in_response_to_contains.split(' ')
134
+ ])
135
+ kwargs['search_in_response_to'] = re.compile(or_regex)
136
+
130
137
  mongo_ordering = []
131
138
 
132
139
  if order_by:
@@ -159,13 +166,6 @@ class MongoDatabaseAdapter(StorageAdapter):
159
166
  if 'tags' in kwargs:
160
167
  kwargs['tags'] = list(set(kwargs['tags']))
161
168
 
162
- if 'search_text' not in kwargs:
163
- kwargs['search_text'] = self.tagger.get_text_index_string(kwargs['text'])
164
-
165
- if 'search_in_response_to' not in kwargs:
166
- if kwargs.get('in_response_to'):
167
- kwargs['search_in_response_to'] = self.tagger.get_text_index_string(kwargs['in_response_to'])
168
-
169
169
  inserted = self.statements.insert_one(kwargs)
170
170
 
171
171
  kwargs['id'] = inserted.inserted_id
@@ -178,20 +178,6 @@ class MongoDatabaseAdapter(StorageAdapter):
178
178
  """
179
179
  create_statements = []
180
180
 
181
- # Check if any statements already have a search text
182
- have_search_text = any(statement.search_text for statement in statements)
183
-
184
- # Generate search text values in bulk
185
- if not have_search_text:
186
- search_text_documents = self.tagger.as_nlp_pipeline([statement.text for statement in statements])
187
- response_search_text_documents = self.tagger.as_nlp_pipeline([statement.in_response_to or '' for statement in statements])
188
-
189
- for statement, search_text_document, response_search_text_document in zip(
190
- statements, search_text_documents, response_search_text_documents
191
- ):
192
- statement.search_text = search_text_document._.search_index
193
- statement.search_in_response_to = response_search_text_document._.search_index
194
-
195
181
  for statement in statements:
196
182
  statement_data = statement.serialize()
197
183
  tag_data = list(set(statement_data.pop('tags', [])))
@@ -206,11 +192,6 @@ class MongoDatabaseAdapter(StorageAdapter):
206
192
  data.pop('id', None)
207
193
  data.pop('tags', None)
208
194
 
209
- data['search_text'] = self.tagger.get_text_index_string(data['text'])
210
-
211
- if data.get('in_response_to'):
212
- data['search_in_response_to'] = self.tagger.get_text_index_string(data['in_response_to'])
213
-
214
195
  update_data = {
215
196
  '$set': data
216
197
  }
@@ -0,0 +1,390 @@
1
+ from datetime import datetime
2
+ from chatterbot.storage import StorageAdapter
3
+ from chatterbot.conversation import Statement as StatementObject
4
+
5
+
6
+ # TODO: This list may not be exhaustive.
7
+ # Is there a full list of characters reserved by redis?
8
+ REDIS_ESCAPE_CHARACTERS = {
9
+ '\\': '\\\\',
10
+ ':': '\\:',
11
+ '|': '\\|',
12
+ '%': '\\%',
13
+ '!': '\\!',
14
+ '-': '\\-',
15
+ }
16
+
17
+ REDIS_TRANSLATION_TABLE = str.maketrans(REDIS_ESCAPE_CHARACTERS)
18
+
19
+ def _escape_redis_special_characters(text):
20
+ """
21
+ Escape special characters in a string that are used in redis queries.
22
+ """
23
+ return text.translate(REDIS_TRANSLATION_TABLE)
24
+
25
+
26
class RedisVectorStorageAdapter(StorageAdapter):
    """
    .. warning:: BETA feature (Released March, 2025): this storage adapter is new
        and experimental. Its functionality and default parameters might change
        in the future and its behavior has not yet been finalized.

    The RedisVectorStorageAdapter allows ChatterBot to store conversation
    data in a redis instance.

    All parameters are optional, by default a redis instance on localhost is assumed.

    :keyword database_uri: eg: redis://localhost:6379/0
        The database_uri can be specified to choose a redis instance.
    :type database_uri: str
    """

    class RedisMetaDataType:
        """
        Subclass for redis config metadata type enumerator.
        """
        TAG = 'tag'
        TEXT = 'text'
        NUMERIC = 'numeric'

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        from chatterbot.vectorstores import RedisVectorStore
        from langchain_redis import RedisConfig
        from langchain_huggingface import HuggingFaceEmbeddings

        self.database_uri = kwargs.get('database_uri', 'redis://localhost:6379/0')

        # The statement's ``in_response_to`` text is the embedded document
        # content (what similarity searches run against); the remaining
        # statement fields are stored as filterable metadata.
        config = RedisConfig(
            index_name='chatterbot',
            redis_url=self.database_uri,
            content_field='in_response_to',
            metadata_schema=[
                {
                    'name': 'conversation',
                    'type': self.RedisMetaDataType.TAG,
                },
                {
                    'name': 'text',
                    'type': self.RedisMetaDataType.TEXT,
                },
                {
                    'name': 'created_at',
                    'type': self.RedisMetaDataType.NUMERIC,
                },
                {
                    'name': 'persona',
                    'type': self.RedisMetaDataType.TEXT,
                },
                {
                    # NOTE(review): tags are written '|'-joined (see
                    # create/create_many/update) but no separator is configured
                    # here; RediSearch TAG fields split on ',' by default, so a
                    # multi-tag statement is probably indexed as one tag.
                    # Confirm before relying on tag filtering.
                    'name': 'tags',
                    'type': self.RedisMetaDataType.TAG,
                },
            ],
        )

        # TODO should this call from_existing_index if connecting to
        # a redis instance that already contains data?

        self.logger.info('Loading HuggingFace embeddings')

        # TODO: Research different embeddings
        # https://python.langchain.com/docs/integrations/vectorstores/mongodb_atlas/#initialization

        embeddings = HuggingFaceEmbeddings(
            model_name='sentence-transformers/all-mpnet-base-v2'
        )

        self.logger.info('Creating Redis Vector Store')

        self.vector_store = RedisVectorStore(embeddings, config=config)

    @staticmethod
    def _to_numeric_date(value):
        """
        Encode a date as the YYMMDD integer stored in the NUMERIC
        ``created_at`` field.

        Values without a ``strftime`` method are assumed to be encoded
        already and are returned unchanged.
        """
        if hasattr(value, 'strftime'):
            return int(value.strftime('%y%m%d'))
        return value

    def _statement_to_metadata(self, statement):
        """
        Build the metadata dict stored alongside a statement's document.

        The vector store does not hold null values, so missing strings are
        stored as '' and tags are stored as a single '|'-joined string.
        """
        return {
            'text': statement.text,
            'conversation': statement.conversation or '',
            'created_at': self._to_numeric_date(statement.created_at or datetime.now()),
            'persona': statement.persona or '',
            'tags': '|'.join(statement.tags) if statement.tags else '',
        }

    def get_statement_model(self):
        """
        Return the statement model.
        """
        from langchain_core.documents import Document

        return Document

    def model_to_object(self, document):
        """
        Convert a langchain ``Document`` from the vector store into a
        ChatterBot ``Statement``.

        The document's ``page_content`` holds the ``in_response_to`` text, and
        its ``metadata`` supplies the remaining statement fields. Tags are
        stored as one '|'-joined string and are split back into a list here.
        """
        # The vector store does not use null values, so statements without a
        # response are stored with '' content; map that back to None.
        in_response_to = document.page_content or None

        values = {
            'in_response_to': in_response_to,
        }

        if document.id:
            values['id'] = document.id

        values.update(document.metadata)

        # Do not assume every stored document carries a 'tags' entry.
        tags = values.get('tags')
        values['tags'] = list(set(tags.split('|'))) if tags else []

        return StatementObject(**values)

    def count(self):
        """
        Return the number of entries in the database.
        """
        # NOTE(review): ``dbsize`` counts every key in the selected redis
        # database, not only the documents under this adapter's index prefix,
        # so unrelated data in the same database inflates this number.
        client = self.vector_store.index.client
        return client.dbsize()

    def remove(self, statement):
        """
        Delete ``statement`` from the vector store.

        ``statement.id`` is expected to be a ``'<prefix>:<id>'`` redis key;
        its second ':'-separated segment is the id passed to ``delete``.
        """
        self.vector_store.delete(ids=[statement.id.split(':')[1]])

    def filter(self, page_size=4, **kwargs):
        """
        Return a list of statements matching the given keyword arguments.

        All supplied filters must match (the conditions are AND-ed together).

        :param page_size: The maximum number of statements to return.

        kwargs:
            - conversation
            - persona
            - tags
            - in_response_to
            - text
            - exclude_text
            - exclude_text_words
            - persona_not_startswith
            - search_in_response_to_contains: when present, a similarity
              search over the embedded ``in_response_to`` content is run
              instead of a plain filtered query
            - order_by
        """
        from redisvl.query.filter import Tag, Text

        # https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/query_syntax/
        conditions = []

        if 'in_response_to' in kwargs:
            conditions.append(Text('in_response_to') == kwargs['in_response_to'])

        if 'conversation' in kwargs:
            conditions.append(Tag('conversation') == kwargs['conversation'])

        if 'persona' in kwargs:
            # 'persona' is declared as a TEXT field in the index schema, so it
            # must be queried with Text (tag syntax does not apply to it).
            conditions.append(Text('persona') == kwargs['persona'])

        if 'tags' in kwargs:
            conditions.append(Tag('tags') == kwargs['tags'])

        # Empty / None exclusion values exclude nothing, so they add no
        # condition (building a query from them would be malformed).
        exclude_text = kwargs.get('exclude_text')
        if exclude_text:
            conditions.append(Text('text') != '|'.join([
                f'%%{text}%%' for text in exclude_text
            ]))

        exclude_text_words = kwargs.get('exclude_text_words')
        if exclude_text_words:
            _query = '|'.join([
                f'%%{_escape_redis_special_characters(word)}%%'
                for word in exclude_text_words
            ])
            conditions.append(Text('text') % f'-({_query})')

        persona_not_startswith = kwargs.get('persona_not_startswith')
        if persona_not_startswith:
            _query = _escape_redis_special_characters(persona_not_startswith)
            conditions.append(Text('persona') % f'-(%%{_query}%%)')

        text = kwargs.get('text')
        if text:
            words = _escape_redis_special_characters(text).split()
            conditions.append(
                Text('text') % '|'.join([f'%%{word}%%' for word in words])
            )

        filter_condition = None
        for condition in conditions:
            if filter_condition is None:
                filter_condition = condition
            else:
                filter_condition = filter_condition & condition

        ordering = kwargs.get('order_by', None)

        if ordering:
            ordering = ','.join(ordering)

        if 'search_in_response_to_contains' in kwargs:
            _search_text = kwargs.get('search_in_response_to_contains', '')

            # TODO similarity_search_with_score
            documents = self.vector_store.similarity_search(
                _search_text,
                k=page_size,  # The number of results to return
                return_all=True,  # Include the full document with IDs
                filter=filter_condition,
                sort_by=ordering
            )
        else:
            documents = self.vector_store.query_search(
                k=page_size,
                filter=filter_condition,
                sort_by=ordering
            )

        return [self.model_to_object(document) for document in documents]

    def create(
        self,
        text,
        in_response_to=None,
        tags=None,
        **kwargs
    ):
        """
        Creates a new statement matching the keyword arguments specified.
        Returns the created statement.

        :param text: The statement text (stored as metadata).
        :param in_response_to: The text this statement responds to; stored as
            the embedded document content ('' when None).
        :param tags: Optional iterable of tag names, stored '|'-joined.
        :returns: The created ``Statement``; its ``id`` is the key returned by
            the vector store's ``add_texts``.
        """
        created_at = kwargs.get('created_at') or datetime.now()

        metadata = {
            'text': text,
            'category': kwargs.get('category') or '',
            # NOTE: `created_at` must have a valid numeric value or results will
            # not be returned for similarity_search for some reason, so a
            # datetime passed by the caller is encoded as YYMMDD.
            'created_at': self._to_numeric_date(created_at),
            'tags': '|'.join(tags) if tags else '',
            # The vector store does not use null values.
            'conversation': kwargs.get('conversation') or '',
            'persona': kwargs.get('persona') or '',
        }

        ids = self.vector_store.add_texts([in_response_to or ''], [metadata])

        # The returned statement carries the caller's values, not the storage
        # encodings (numeric date, '|'-joined tags).
        metadata['created_at'] = created_at
        metadata['tags'] = tags or []
        metadata.pop('text')

        return StatementObject(
            id=ids[0],
            text=text,
            in_response_to=in_response_to,
            **metadata
        )

    def create_many(self, statements):
        """
        Creates multiple statement entries.
        """
        Document = self.get_statement_model()
        documents = [
            Document(
                page_content=statement.in_response_to or '',
                metadata=self._statement_to_metadata(statement),
            ) for statement in statements
        ]

        self.logger.info('Adding documents to the vector store')

        self.vector_store.add_documents(documents)

    def update(self, statement):
        """
        Modifies an entry in the database.
        Creates an entry if one does not exist.

        A statement with an ``id`` is written back under that key; one without
        an ``id`` is added as a new document.
        """
        metadata = self._statement_to_metadata(statement)

        Document = self.get_statement_model()
        document = Document(
            page_content=statement.in_response_to or '',
            metadata=metadata,
        )

        if statement.id:
            # Ids are '<prefix>:<id>' redis keys; only the id part is passed
            # as the key.
            self.vector_store.add_texts(
                [document.page_content], [metadata], keys=[statement.id.split(':')[1]]
            )
        else:
            self.vector_store.add_documents([document])

    def get_random(self):
        """
        Returns a random statement from the database.

        :raises EmptyDatabaseException: if no random key is available or it
            does not resolve to a stored document.
        """
        client = self.vector_store.index.client

        random_key = client.randomkey()

        if random_key:
            # redis-py returns bytes unless the client decodes responses.
            if isinstance(random_key, bytes):
                random_key = random_key.decode()

            random_id = random_key.split(':')[1]

            documents = self.vector_store.get_by_ids([random_id])

            if documents:
                return self.model_to_object(documents[0])

        raise self.EmptyDatabaseException()

    def drop(self):
        """
        Remove all existing documents from the database.

        Only the keys under the index prefix are deleted; the index itself is
        kept.
        """
        index_name = self.vector_store.config.index_name
        client = self.vector_store.index.client

        for key in client.scan_iter(f'{index_name}:*'):
            client.delete(key)

        # The index is intentionally not dropped (e.g. via
        # ``self.vector_store.index.delete(drop=True)``): nothing here recreates
        # it afterwards. The goal is to delete all keys while keeping the index.
@@ -114,8 +114,8 @@ class SQLStorageAdapter(StorageAdapter):
114
114
  record = query.first()
115
115
 
116
116
  session.delete(record)
117
-
118
- self._session_finish(session)
117
+ session.commit()
118
+ session.close()
119
119
 
120
120
  def filter(self, **kwargs):
121
121
  """
@@ -139,6 +139,7 @@ class SQLStorageAdapter(StorageAdapter):
139
139
  exclude_text_words = kwargs.pop('exclude_text_words', [])
140
140
  persona_not_startswith = kwargs.pop('persona_not_startswith', None)
141
141
  search_text_contains = kwargs.pop('search_text_contains', None)
142
+ search_in_response_to_contains = kwargs.pop('search_in_response_to_contains', None)
142
143
 
143
144
  # Convert a single sting into a list if only one tag is provided
144
145
  if type(tags) == str:
@@ -180,6 +181,14 @@ class SQLStorageAdapter(StorageAdapter):
180
181
  or_(*or_query)
181
182
  )
182
183
 
184
+ if search_in_response_to_contains:
185
+ or_query = [
186
+ Statement.search_in_response_to.contains(word) for word in search_in_response_to_contains.split(' ')
187
+ ]
188
+ statements = statements.filter(
189
+ or_(*or_query)
190
+ )
191
+
183
192
  if order_by:
184
193
 
185
194
  if 'created_at' in order_by:
@@ -196,7 +205,15 @@ class SQLStorageAdapter(StorageAdapter):
196
205
 
197
206
  session.close()
198
207
 
199
- def create(self, **kwargs):
208
+ def create(
209
+ self,
210
+ text,
211
+ in_response_to=None,
212
+ tags=None,
213
+ search_text=None,
214
+ search_in_response_to=None,
215
+ **kwargs
216
+ ):
200
217
  """
201
218
  Creates a new statement matching the keyword arguments specified.
202
219
  Returns the created statement.
@@ -206,19 +223,25 @@ class SQLStorageAdapter(StorageAdapter):
206
223
 
207
224
  session = self.Session()
208
225
 
209
- tags = set(kwargs.pop('tags', []))
210
-
211
- if 'search_text' not in kwargs:
212
- kwargs['search_text'] = self.tagger.get_text_index_string(kwargs['text'])
213
-
214
- if 'search_in_response_to' not in kwargs:
215
- in_response_to = kwargs.get('in_response_to')
216
- if in_response_to:
217
- kwargs['search_in_response_to'] = self.tagger.get_text_index_string(in_response_to)
218
-
219
- statement = Statement(**kwargs)
220
-
221
- for tag_name in tags:
226
+ if search_text is None:
227
+ if self.raise_on_missing_search_text:
228
+ raise Exception('generate a search_text value')
229
+
230
+ if search_in_response_to is None and in_response_to is not None:
231
+ if self.raise_on_missing_search_text:
232
+ raise Exception('generate a search_in_response_to value')
233
+
234
+ statement = Statement(
235
+ text=text,
236
+ in_response_to=in_response_to,
237
+ search_text=search_text,
238
+ search_in_response_to=search_in_response_to,
239
+ **kwargs
240
+ )
241
+
242
+ tags = frozenset(tags) if tags else frozenset()
243
+ for tag_name in frozenset(tags):
244
+ # TODO: Query existing tags in bulk
222
245
  tag = session.query(Tag).filter_by(name=tag_name).first()
223
246
 
224
247
  if not tag:
@@ -235,7 +258,7 @@ class SQLStorageAdapter(StorageAdapter):
235
258
 
236
259
  statement_object = self.model_to_object(statement)
237
260
 
238
- self._session_finish(session)
261
+ session.close()
239
262
 
240
263
  return statement_object
241
264
 
@@ -256,14 +279,8 @@ class SQLStorageAdapter(StorageAdapter):
256
279
 
257
280
  # Generate search text values in bulk
258
281
  if not have_search_text:
259
- search_text_documents = self.tagger.as_nlp_pipeline([statement.text for statement in statements])
260
- response_search_text_documents = self.tagger.as_nlp_pipeline([statement.in_response_to or '' for statement in statements])
261
-
262
- for statement, search_text_document, response_search_text_document in zip(
263
- statements, search_text_documents, response_search_text_documents
264
- ):
265
- statement.search_text = search_text_document._.search_index
266
- statement.search_in_response_to = response_search_text_document._.search_index
282
+ if self.raise_on_missing_search_text:
283
+ raise Exception('generate bulk_search_text values')
267
284
 
268
285
  for statement in statements:
269
286
 
@@ -305,48 +322,50 @@ class SQLStorageAdapter(StorageAdapter):
305
322
  Statement = self.get_model('statement')
306
323
  Tag = self.get_model('tag')
307
324
 
308
- if statement is not None:
309
- session = self.Session()
310
- record = None
311
-
312
- if hasattr(statement, 'id') and statement.id is not None:
313
- record = session.query(Statement).get(statement.id)
314
- else:
315
- record = session.query(Statement).filter(
316
- Statement.text == statement.text,
317
- Statement.conversation == statement.conversation,
318
- ).first()
319
-
320
- # Create a new statement entry if one does not already exist
321
- if not record:
322
- record = Statement(
323
- text=statement.text,
324
- conversation=statement.conversation,
325
- persona=statement.persona
326
- )
325
+ session = self.Session()
326
+ record = None
327
327
 
328
- # Update the response value
329
- record.in_response_to = statement.in_response_to
328
+ if hasattr(statement, 'id') and statement.id is not None:
329
+ record = session.query(Statement).get(statement.id)
330
+ else:
331
+ record = session.query(Statement).filter(
332
+ Statement.text == statement.text,
333
+ Statement.conversation == statement.conversation,
334
+ ).first()
335
+
336
+ # Create a new statement entry if one does not already exist
337
+ if not record:
338
+ record = Statement(
339
+ text=statement.text,
340
+ conversation=statement.conversation,
341
+ persona=statement.persona
342
+ )
330
343
 
331
- record.created_at = statement.created_at
344
+ # Update the response value
345
+ record.in_response_to = statement.in_response_to
332
346
 
333
- record.search_text = self.tagger.get_text_index_string(statement.text)
347
+ record.created_at = statement.created_at
334
348
 
335
- if statement.in_response_to:
336
- record.search_in_response_to = self.tagger.get_text_index_string(statement.in_response_to)
349
+ if not statement.search_text:
350
+ if self.raise_on_missing_search_text:
351
+ raise Exception('update issued without search_text value')
337
352
 
338
- for tag_name in statement.get_tags():
339
- tag = session.query(Tag).filter_by(name=tag_name).first()
353
+ if statement.in_response_to and not statement.search_in_response_to:
354
+ if self.raise_on_missing_search_text:
355
+ raise Exception('update issued without search_in_response_to value')
340
356
 
341
- if not tag:
342
- # Create the record
343
- tag = Tag(name=tag_name)
357
+ for tag_name in statement.get_tags():
358
+ tag = session.query(Tag).filter_by(name=tag_name).first()
344
359
 
345
- record.tags.append(tag)
360
+ if not tag:
361
+ # Create the record
362
+ tag = Tag(name=tag_name)
346
363
 
347
- session.add(record)
364
+ record.tags.append(tag)
348
365
 
349
- self._session_finish(session)
366
+ session.add(record)
367
+ session.commit()
368
+ session.close()
350
369
 
351
370
  def get_random(self):
352
371
  """
@@ -388,13 +407,3 @@ class SQLStorageAdapter(StorageAdapter):
388
407
  """
389
408
  from chatterbot.ext.sqlalchemy_app.models import Base
390
409
  Base.metadata.create_all(self.engine)
391
-
392
- def _session_finish(self, session, statement_text=None):
393
- from sqlalchemy.exc import InvalidRequestError
394
- try:
395
- session.commit()
396
- except InvalidRequestError:
397
- # Log the statement text and the exception
398
- self.logger.exception(statement_text)
399
- finally:
400
- session.close()