ChatterBot 1.2.2__py3-none-any.whl → 1.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatterbot/__init__.py +1 -1
- chatterbot/__main__.py +15 -0
- chatterbot/chatterbot.py +52 -6
- chatterbot/comparisons.py +3 -13
- chatterbot/ext/sqlalchemy_app/models.py +2 -2
- chatterbot/logic/specific_response.py +48 -9
- chatterbot/logic/time_adapter.py +3 -7
- chatterbot/logic/unit_conversion.py +4 -3
- chatterbot/storage/__init__.py +2 -0
- chatterbot/storage/redis.py +390 -0
- chatterbot/storage/sql_storage.py +1 -1
- chatterbot/tagging.py +3 -7
- chatterbot/trainers.py +297 -109
- chatterbot/utils.py +16 -25
- chatterbot/vectorstores.py +74 -0
- {chatterbot-1.2.2.dist-info → chatterbot-1.2.4.dist-info}/METADATA +16 -20
- {chatterbot-1.2.2.dist-info → chatterbot-1.2.4.dist-info}/RECORD +20 -18
- {chatterbot-1.2.2.dist-info → chatterbot-1.2.4.dist-info}/WHEEL +1 -1
- {chatterbot-1.2.2.dist-info → chatterbot-1.2.4.dist-info/licenses}/LICENSE +0 -0
- {chatterbot-1.2.2.dist-info → chatterbot-1.2.4.dist-info}/top_level.txt +0 -0
chatterbot/__init__.py
CHANGED
chatterbot/__main__.py
CHANGED
@@ -1,7 +1,16 @@
|
|
1
|
+
"""
|
2
|
+
Example usage for ChatterBot command line arguments:
|
3
|
+
|
4
|
+
python -m chatterbot --help
|
5
|
+
"""
|
6
|
+
|
1
7
|
import sys
|
2
8
|
|
3
9
|
|
4
10
|
def get_chatterbot_version():
|
11
|
+
"""
|
12
|
+
Return the version of the current package.
|
13
|
+
"""
|
5
14
|
from chatterbot import __version__
|
6
15
|
|
7
16
|
return __version__
|
@@ -10,3 +19,9 @@ def get_chatterbot_version():
|
|
10
19
|
if __name__ == '__main__':
|
11
20
|
if '--version' in sys.argv:
|
12
21
|
print(get_chatterbot_version())
|
22
|
+
elif '--help' in sys.argv:
|
23
|
+
print('usage: chatterbot [--version, --help]')
|
24
|
+
print(' --version: Print the version of ChatterBot')
|
25
|
+
print(' --help: Print this help message')
|
26
|
+
print()
|
27
|
+
print('Documentation at https://docs.chatterbot.us')
|
chatterbot/chatterbot.py
CHANGED
@@ -11,11 +11,41 @@ import spacy
|
|
11
11
|
class ChatBot(object):
|
12
12
|
"""
|
13
13
|
A conversational dialog chat bot.
|
14
|
+
|
15
|
+
:param name: A name is the only required parameter for the ChatBot class.
|
16
|
+
:type name: str
|
17
|
+
|
18
|
+
:keyword storage_adapter: The dot-notated import path to a storage adapter class.
|
19
|
+
Defaults to ``"chatterbot.storage.SQLStorageAdapter"``.
|
20
|
+
:type storage_adapter: str
|
21
|
+
|
22
|
+
:param logic_adapters: A list of dot-notated import paths to each logic adapter the bot uses.
|
23
|
+
Defaults to ``["chatterbot.logic.BestMatch"]``.
|
24
|
+
:type logic_adapters: list
|
25
|
+
|
26
|
+
:param tagger: The tagger to use for the chat bot.
|
27
|
+
Defaults to :class:`~chatterbot.tagging.PosLemmaTagger`
|
28
|
+
:type tagger: object
|
29
|
+
|
30
|
+
:param tagger_language: The language to use for the tagger.
|
31
|
+
Defaults to :class:`~chatterbot.languages.ENG`.
|
32
|
+
:type tagger_language: object
|
33
|
+
|
34
|
+
:param preprocessors: A list of preprocessor functions to use for the chat bot.
|
35
|
+
:type preprocessors: list
|
36
|
+
|
37
|
+
:param read_only: If True, the chat bot will not save any input it receives, defaults to False.
|
38
|
+
:type read_only: bool
|
39
|
+
|
40
|
+
:param logger: A ``Logger`` object.
|
41
|
+
:type logger: logging.Logger
|
14
42
|
"""
|
15
43
|
|
16
44
|
def __init__(self, name, **kwargs):
|
17
45
|
self.name = name
|
18
46
|
|
47
|
+
self.logger = kwargs.get('logger', logging.getLogger(__name__))
|
48
|
+
|
19
49
|
storage_adapter = kwargs.get('storage_adapter', 'chatterbot.storage.SQLStorageAdapter')
|
20
50
|
|
21
51
|
logic_adapters = kwargs.get('logic_adapters', [
|
@@ -30,11 +60,29 @@ class ChatBot(object):
|
|
30
60
|
|
31
61
|
self.storage = utils.initialize_class(storage_adapter, **kwargs)
|
32
62
|
|
33
|
-
|
63
|
+
tagger_language = kwargs.get('tagger_language', languages.ENG)
|
34
64
|
|
35
|
-
|
36
|
-
'
|
37
|
-
|
65
|
+
try:
|
66
|
+
Tagger = kwargs.get('tagger', PosLemmaTagger)
|
67
|
+
|
68
|
+
self.tagger = Tagger(language=tagger_language)
|
69
|
+
except IOError as io_error:
|
70
|
+
# Return a more helpful error message if possible
|
71
|
+
if "Can't find model" in str(io_error):
|
72
|
+
model_name = utils.get_model_for_language(tagger_language)
|
73
|
+
if hasattr(tagger_language, 'ENGLISH_NAME'):
|
74
|
+
language_name = tagger_language.ENGLISH_NAME
|
75
|
+
else:
|
76
|
+
language_name = tagger_language
|
77
|
+
raise self.ChatBotException(
|
78
|
+
'Setup error:\n'
|
79
|
+
f'The Spacy model for "{language_name}" language is missing.\n'
|
80
|
+
'Please install the model using the command:\n\n'
|
81
|
+
f'python -m spacy download {model_name}\n\n'
|
82
|
+
'See https://spacy.io/usage/models for more information about available models.'
|
83
|
+
) from io_error
|
84
|
+
else:
|
85
|
+
raise io_error
|
38
86
|
|
39
87
|
primary_search_algorithm = IndexedTextSearch(self, **kwargs)
|
40
88
|
text_search_algorithm = TextSearch(self, **kwargs)
|
@@ -63,8 +111,6 @@ class ChatBot(object):
|
|
63
111
|
# NOTE: 'xx' is the language code for a multi-language model
|
64
112
|
self.nlp = spacy.blank(self.tagger.language.ISO_639_1)
|
65
113
|
|
66
|
-
self.logger = kwargs.get('logger', logging.getLogger(__name__))
|
67
|
-
|
68
114
|
# Allow the bot to save input it receives so that it can learn
|
69
115
|
self.read_only = kwargs.get('read_only', False)
|
70
116
|
|
chatterbot/comparisons.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2
2
|
This module contains various text-comparison algorithms
|
3
3
|
designed to compare one statement to another.
|
4
4
|
"""
|
5
|
-
from chatterbot import constants
|
5
|
+
from chatterbot.utils import get_model_for_language
|
6
6
|
from difflib import SequenceMatcher
|
7
7
|
import spacy
|
8
8
|
|
@@ -100,12 +100,7 @@ class SpacySimilarity(Comparator):
|
|
100
100
|
def __init__(self, language):
|
101
101
|
super().__init__(language)
|
102
102
|
|
103
|
-
try:
|
104
|
-
model = constants.DEFAULT_LANGUAGE_TO_SPACY_MODEL_MAP[self.language]
|
105
|
-
except KeyError as e:
|
106
|
-
raise KeyError(
|
107
|
-
f'Spacy model is not available for language {self.language}'
|
108
|
-
) from e
|
103
|
+
model = get_model_for_language(language)
|
109
104
|
|
110
105
|
# Disable the Named Entity Recognition (NER) component because it is not necessary
|
111
106
|
self.nlp = spacy.load(model, exclude=['ner'])
|
@@ -157,12 +152,7 @@ class JaccardSimilarity(Comparator):
|
|
157
152
|
def __init__(self, language):
|
158
153
|
super().__init__(language)
|
159
154
|
|
160
|
-
try:
|
161
|
-
model = constants.DEFAULT_LANGUAGE_TO_SPACY_MODEL_MAP[self.language]
|
162
|
-
except KeyError as e:
|
163
|
-
raise KeyError(
|
164
|
-
f'Spacy model is not available for language {self.language}'
|
165
|
-
) from e
|
155
|
+
model = get_model_for_language(language)
|
166
156
|
|
167
157
|
# Disable the Named Entity Recognition (NER) component because it is not necessary
|
168
158
|
self.nlp = spacy.load(model, exclude=['ner'])
|
chatterbot/ext/sqlalchemy_app/models.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
from sqlalchemy import Table, Column, Integer, String, DateTime, ForeignKey
|
2
|
-
from sqlalchemy.orm import relationship
|
2
|
+
from sqlalchemy.orm import relationship, declarative_base
|
3
3
|
from sqlalchemy.sql import func
|
4
|
-
from sqlalchemy.ext.declarative import declared_attr
|
4
|
+
from sqlalchemy.ext.declarative import declared_attr
|
5
5
|
|
6
6
|
from chatterbot.conversation import StatementMixin
|
7
7
|
from chatterbot import constants
|
chatterbot/logic/specific_response.py
CHANGED
@@ -1,4 +1,8 @@
|
|
1
1
|
from chatterbot.logic import LogicAdapter
|
2
|
+
from chatterbot.conversation import Statement
|
3
|
+
from chatterbot import languages
|
4
|
+
from chatterbot.utils import get_model_for_language
|
5
|
+
import spacy
|
2
6
|
|
3
7
|
|
4
8
|
class SpecificResponseAdapter(LogicAdapter):
|
@@ -8,30 +12,65 @@ class SpecificResponseAdapter(LogicAdapter):
|
|
8
12
|
:kwargs:
|
9
13
|
* *input_text* (``str``) --
|
10
14
|
The input text that triggers this logic adapter.
|
11
|
-
* *output_text* (``str``) --
|
15
|
+
* *output_text* (``str`` or ``function``) --
|
12
16
|
The output text returned by this logic adapter.
|
17
|
+
If a function is provided, it should return a string.
|
13
18
|
"""
|
14
19
|
|
15
20
|
def __init__(self, chatbot, **kwargs):
|
16
21
|
super().__init__(chatbot, **kwargs)
|
17
|
-
from chatterbot.conversation import Statement
|
18
22
|
|
19
23
|
self.input_text = kwargs.get('input_text')
|
20
24
|
|
21
|
-
|
22
|
-
|
25
|
+
self.matcher = None
|
26
|
+
|
27
|
+
if MatcherClass := kwargs.get('matcher'):
|
28
|
+
language = kwargs.get('language', languages.ENG)
|
29
|
+
|
30
|
+
self.nlp = self._initialize_nlp(language)
|
31
|
+
|
32
|
+
self.matcher = MatcherClass(self.nlp.vocab)
|
33
|
+
|
34
|
+
self.matcher.add('SpecificResponse', [self.input_text])
|
35
|
+
|
36
|
+
self._output_text = kwargs.get('output_text')
|
37
|
+
|
38
|
+
def _initialize_nlp(self, language):
|
39
|
+
model = get_model_for_language(language)
|
40
|
+
|
41
|
+
return spacy.load(model)
|
23
42
|
|
24
43
|
def can_process(self, statement):
|
25
|
-
if
|
44
|
+
if self.matcher:
|
45
|
+
doc = self.nlp(statement.text)
|
46
|
+
matches = self.matcher(doc)
|
47
|
+
|
48
|
+
if matches:
|
49
|
+
return True
|
50
|
+
elif statement.text == self.input_text:
|
26
51
|
return True
|
27
52
|
|
28
53
|
return False
|
29
54
|
|
30
55
|
def process(self, statement, additional_response_selection_parameters=None):
|
31
56
|
|
32
|
-
if
|
33
|
-
|
57
|
+
if callable(self._output_text):
|
58
|
+
response_statement = Statement(text=self._output_text())
|
59
|
+
else:
|
60
|
+
response_statement = Statement(text=self._output_text)
|
61
|
+
|
62
|
+
if self.matcher:
|
63
|
+
doc = self.nlp(statement.text)
|
64
|
+
matches = self.matcher(doc)
|
65
|
+
|
66
|
+
if matches:
|
67
|
+
response_statement.confidence = 1
|
68
|
+
else:
|
69
|
+
response_statement.confidence = 0
|
70
|
+
|
71
|
+
elif statement.text == self.input_text:
|
72
|
+
response_statement.confidence = 1
|
34
73
|
else:
|
35
|
-
|
74
|
+
response_statement.confidence = 0
|
36
75
|
|
37
|
-
return
|
76
|
+
return response_statement
|
chatterbot/logic/time_adapter.py
CHANGED
@@ -1,7 +1,8 @@
|
|
1
1
|
from datetime import datetime
|
2
|
-
from chatterbot import
|
2
|
+
from chatterbot import languages
|
3
3
|
from chatterbot.logic import LogicAdapter
|
4
4
|
from chatterbot.conversation import Statement
|
5
|
+
from chatterbot.utils import get_model_for_language
|
5
6
|
import spacy
|
6
7
|
|
7
8
|
|
@@ -36,12 +37,7 @@ class TimeLogicAdapter(LogicAdapter):
|
|
36
37
|
|
37
38
|
language = kwargs.get('language', languages.ENG)
|
38
39
|
|
39
|
-
try:
|
40
|
-
model = constants.DEFAULT_LANGUAGE_TO_SPACY_MODEL_MAP[language]
|
41
|
-
except KeyError as e:
|
42
|
-
raise KeyError(
|
43
|
-
f'Spacy model is not available for language {language}'
|
44
|
-
) from e
|
40
|
+
model = get_model_for_language(language)
|
45
41
|
|
46
42
|
self.nlp = spacy.load(model)
|
47
43
|
|
chatterbot/logic/unit_conversion.py
CHANGED
@@ -158,7 +158,8 @@ class UnitConversion(LogicAdapter):
|
|
158
158
|
response = func(p)
|
159
159
|
if response.confidence == 1.0:
|
160
160
|
break
|
161
|
-
except Exception:
|
161
|
+
except Exception as e:
|
162
|
+
self.chatbot.logger.warning('Error during UnitConversion: {}'.format(str(e)))
|
162
163
|
response.confidence = 0.0
|
163
|
-
|
164
|
-
|
164
|
+
|
165
|
+
return response
|
chatterbot/storage/__init__.py
CHANGED
@@ -2,6 +2,7 @@ from chatterbot.storage.storage_adapter import StorageAdapter
|
|
2
2
|
from chatterbot.storage.django_storage import DjangoStorageAdapter
|
3
3
|
from chatterbot.storage.mongodb import MongoDatabaseAdapter
|
4
4
|
from chatterbot.storage.sql_storage import SQLStorageAdapter
|
5
|
+
from chatterbot.storage.redis import RedisVectorStorageAdapter
|
5
6
|
|
6
7
|
|
7
8
|
__all__ = (
|
@@ -9,4 +10,5 @@ __all__ = (
|
|
9
10
|
'DjangoStorageAdapter',
|
10
11
|
'MongoDatabaseAdapter',
|
11
12
|
'SQLStorageAdapter',
|
13
|
+
'RedisVectorStorageAdapter',
|
12
14
|
)
|
chatterbot/storage/redis.py
ADDED
@@ -0,0 +1,390 @@
|
|
1
|
+
from datetime import datetime
|
2
|
+
from chatterbot.storage import StorageAdapter
|
3
|
+
from chatterbot.conversation import Statement as StatementObject
|
4
|
+
|
5
|
+
|
6
|
+
# TODO: This list may not be exhaustive.
|
7
|
+
# Is there a full list of characters reserved by redis?
|
8
|
+
REDIS_ESCAPE_CHARACTERS = {
|
9
|
+
'\\': '\\\\',
|
10
|
+
':': '\\:',
|
11
|
+
'|': '\\|',
|
12
|
+
'%': '\\%',
|
13
|
+
'!': '\\!',
|
14
|
+
'-': '\\-',
|
15
|
+
}
|
16
|
+
|
17
|
+
REDIS_TRANSLATION_TABLE = str.maketrans(REDIS_ESCAPE_CHARACTERS)
|
18
|
+
|
19
|
+
def _escape_redis_special_characters(text):
|
20
|
+
"""
|
21
|
+
Escape special characters in a string that are used in redis queries.
|
22
|
+
"""
|
23
|
+
return text.translate(REDIS_TRANSLATION_TABLE)
|
24
|
+
|
25
|
+
|
26
|
+
class RedisVectorStorageAdapter(StorageAdapter):
|
27
|
+
"""
|
28
|
+
.. warning:: BETA feature (Released March, 2025): this storage adapter is new
|
29
|
+
and experimental. Its functionality and default parameters might change
|
30
|
+
in the future and its behavior has not yet been finalized.
|
31
|
+
|
32
|
+
The RedisVectorStorageAdapter allows ChatterBot to store conversation
|
33
|
+
data in a redis instance.
|
34
|
+
|
35
|
+
All parameters are optional, by default a redis instance on localhost is assumed.
|
36
|
+
|
37
|
+
:keyword database_uri: eg: redis://localhost:6379/0',
|
38
|
+
The database_uri can be specified to choose a redis instance.
|
39
|
+
:type database_uri: str
|
40
|
+
"""
|
41
|
+
|
42
|
+
class RedisMetaDataType:
|
43
|
+
"""
|
44
|
+
Subclass for redis config metadata type enumerator.
|
45
|
+
"""
|
46
|
+
TAG = 'tag'
|
47
|
+
TEXT = 'text'
|
48
|
+
NUMERIC = 'numeric'
|
49
|
+
|
50
|
+
def __init__(self, **kwargs):
|
51
|
+
super().__init__(**kwargs)
|
52
|
+
from chatterbot.vectorstores import RedisVectorStore
|
53
|
+
from langchain_redis import RedisConfig # RedisVectorStore
|
54
|
+
from langchain_huggingface import HuggingFaceEmbeddings
|
55
|
+
|
56
|
+
self.database_uri = kwargs.get('database_uri', 'redis://localhost:6379/0')
|
57
|
+
|
58
|
+
config = RedisConfig(
|
59
|
+
index_name='chatterbot',
|
60
|
+
redis_url=self.database_uri,
|
61
|
+
content_field='in_response_to',
|
62
|
+
metadata_schema=[
|
63
|
+
{
|
64
|
+
'name': 'conversation',
|
65
|
+
'type': self.RedisMetaDataType.TAG,
|
66
|
+
},
|
67
|
+
{
|
68
|
+
'name': 'text',
|
69
|
+
'type': self.RedisMetaDataType.TEXT,
|
70
|
+
},
|
71
|
+
{
|
72
|
+
'name': 'created_at',
|
73
|
+
'type': self.RedisMetaDataType.NUMERIC,
|
74
|
+
},
|
75
|
+
{
|
76
|
+
'name': 'persona',
|
77
|
+
'type': self.RedisMetaDataType.TEXT,
|
78
|
+
},
|
79
|
+
{
|
80
|
+
'name': 'tags',
|
81
|
+
'type': self.RedisMetaDataType.TAG,
|
82
|
+
# 'separator': '|'
|
83
|
+
},
|
84
|
+
],
|
85
|
+
)
|
86
|
+
|
87
|
+
# TODO should this call from_existing_index if connecting to
|
88
|
+
# a redis instance that already contains data?
|
89
|
+
|
90
|
+
self.logger.info('Loading HuggingFace embeddings')
|
91
|
+
|
92
|
+
# TODO: Research different embeddings
|
93
|
+
# https://python.langchain.com/docs/integrations/vectorstores/mongodb_atlas/#initialization
|
94
|
+
|
95
|
+
embeddings = HuggingFaceEmbeddings(
|
96
|
+
model_name='sentence-transformers/all-mpnet-base-v2'
|
97
|
+
)
|
98
|
+
|
99
|
+
self.logger.info('Creating Redis Vector Store')
|
100
|
+
|
101
|
+
self.vector_store = RedisVectorStore(embeddings, config=config)
|
102
|
+
|
103
|
+
def get_statement_model(self):
|
104
|
+
"""
|
105
|
+
Return the statement model.
|
106
|
+
"""
|
107
|
+
from langchain_core.documents import Document
|
108
|
+
|
109
|
+
return Document
|
110
|
+
|
111
|
+
def model_to_object(self, document):
|
112
|
+
|
113
|
+
in_response_to = document.page_content
|
114
|
+
|
115
|
+
# If the value is an empty string, set it to None
|
116
|
+
# to match the expected type (the vector store does
|
117
|
+
# not use null values)
|
118
|
+
if in_response_to == '':
|
119
|
+
in_response_to = None
|
120
|
+
|
121
|
+
values = {
|
122
|
+
'in_response_to': in_response_to,
|
123
|
+
}
|
124
|
+
|
125
|
+
if document.id:
|
126
|
+
values['id'] = document.id
|
127
|
+
|
128
|
+
values.update(document.metadata)
|
129
|
+
|
130
|
+
tags = values['tags']
|
131
|
+
values['tags'] = list(set(tags.split('|') if tags else []))
|
132
|
+
|
133
|
+
return StatementObject(**values)
|
134
|
+
|
135
|
+
def count(self):
|
136
|
+
"""
|
137
|
+
Return the number of entries in the database.
|
138
|
+
"""
|
139
|
+
|
140
|
+
'''
|
141
|
+
TODO
|
142
|
+
faiss_vector_store = FAISS(
|
143
|
+
embedding_function=embedding_function,
|
144
|
+
index=IndexFlatL2(embedding_size),
|
145
|
+
docstore=InMemoryDocstore(),
|
146
|
+
index_to_docstore_id={}
|
147
|
+
)
|
148
|
+
doc_count = faiss_vector_store.index.ntotal
|
149
|
+
'''
|
150
|
+
|
151
|
+
client = self.vector_store.index.client
|
152
|
+
return client.dbsize()
|
153
|
+
|
154
|
+
def remove(self, statement):
|
155
|
+
"""
|
156
|
+
Removes the statement that matches the input text.
|
157
|
+
Removes any responses from statements where the response text matches
|
158
|
+
the input text.
|
159
|
+
"""
|
160
|
+
self.vector_store.delete(ids=[statement.id.split(':')[1]])
|
161
|
+
|
162
|
+
def filter(self, page_size=4, **kwargs):
|
163
|
+
"""
|
164
|
+
Returns a list of objects from the database.
|
165
|
+
The kwargs parameter can contain any number
|
166
|
+
of attributes. Only objects which contain all
|
167
|
+
listed attributes and in which all values match
|
168
|
+
for all listed attributes will be returned.
|
169
|
+
|
170
|
+
kwargs:
|
171
|
+
- conversation
|
172
|
+
- persona
|
173
|
+
- tags
|
174
|
+
- in_response_to
|
175
|
+
- text
|
176
|
+
- exclude_text
|
177
|
+
- exclude_text_words
|
178
|
+
- persona_not_startswith
|
179
|
+
- search_in_response_to_contains
|
180
|
+
- order_by
|
181
|
+
"""
|
182
|
+
from redisvl.query.filter import Tag, Text
|
183
|
+
|
184
|
+
# https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/query_syntax/
|
185
|
+
filter_condition = None
|
186
|
+
|
187
|
+
if 'in_response_to' in kwargs:
|
188
|
+
filter_condition = Text('in_response_to') == kwargs['in_response_to']
|
189
|
+
|
190
|
+
if 'conversation' in kwargs:
|
191
|
+
query = Tag('conversation') == kwargs['conversation']
|
192
|
+
if filter_condition:
|
193
|
+
filter_condition &= query
|
194
|
+
else:
|
195
|
+
filter_condition = query
|
196
|
+
|
197
|
+
if 'persona' in kwargs:
|
198
|
+
query = Tag('persona') == kwargs['persona']
|
199
|
+
if filter_condition:
|
200
|
+
filter_condition &= query
|
201
|
+
else:
|
202
|
+
filter_condition = query
|
203
|
+
|
204
|
+
if 'tags' in kwargs:
|
205
|
+
query = Tag('tags') == kwargs['tags']
|
206
|
+
if filter_condition:
|
207
|
+
filter_condition &= query
|
208
|
+
else:
|
209
|
+
filter_condition = query
|
210
|
+
|
211
|
+
if 'exclude_text' in kwargs:
|
212
|
+
query = Text('text') != '|'.join([
|
213
|
+
f'%%{text}%%' for text in kwargs['exclude_text']
|
214
|
+
])
|
215
|
+
if filter_condition:
|
216
|
+
filter_condition &= query
|
217
|
+
else:
|
218
|
+
filter_condition = query
|
219
|
+
|
220
|
+
if 'exclude_text_words' in kwargs:
|
221
|
+
_query = '|'.join([
|
222
|
+
f'%%{text}%%' for text in kwargs['exclude_text_words']
|
223
|
+
])
|
224
|
+
query = Text('text') % f'-({ _query })'
|
225
|
+
if filter_condition:
|
226
|
+
filter_condition &= query
|
227
|
+
else:
|
228
|
+
filter_condition = query
|
229
|
+
|
230
|
+
if 'persona_not_startswith' in kwargs:
|
231
|
+
_query = _escape_redis_special_characters(kwargs['persona_not_startswith'])
|
232
|
+
query = Text('persona') % f'-(%%{_query}%%)'
|
233
|
+
if filter_condition:
|
234
|
+
filter_condition &= query
|
235
|
+
else:
|
236
|
+
filter_condition = query
|
237
|
+
|
238
|
+
if 'text' in kwargs:
|
239
|
+
_query = _escape_redis_special_characters(kwargs['text'])
|
240
|
+
query = Text('text') % '|'.join([f'%%{_q}%%' for _q in _query.split()])
|
241
|
+
if filter_condition:
|
242
|
+
filter_condition &= query
|
243
|
+
else:
|
244
|
+
filter_condition = query
|
245
|
+
|
246
|
+
ordering = kwargs.get('order_by', None)
|
247
|
+
|
248
|
+
if ordering:
|
249
|
+
ordering = ','.join(ordering)
|
250
|
+
|
251
|
+
if 'search_in_response_to_contains' in kwargs:
|
252
|
+
_search_text = kwargs.get('search_in_response_to_contains', '')
|
253
|
+
|
254
|
+
# TODO similarity_search_with_score
|
255
|
+
documents = self.vector_store.similarity_search(
|
256
|
+
_search_text,
|
257
|
+
k=page_size, # The number of results to return
|
258
|
+
return_all=True, # Include the full document with IDs
|
259
|
+
filter=filter_condition,
|
260
|
+
sort_by=ordering
|
261
|
+
)
|
262
|
+
else:
|
263
|
+
documents = self.vector_store.query_search(
|
264
|
+
k=page_size,
|
265
|
+
filter=filter_condition,
|
266
|
+
sort_by=ordering
|
267
|
+
)
|
268
|
+
|
269
|
+
return [self.model_to_object(document) for document in documents]
|
270
|
+
|
271
|
+
def create(
|
272
|
+
self,
|
273
|
+
text,
|
274
|
+
in_response_to=None,
|
275
|
+
tags=None,
|
276
|
+
**kwargs
|
277
|
+
):
|
278
|
+
"""
|
279
|
+
Creates a new statement matching the keyword arguments specified.
|
280
|
+
Returns the created statement.
|
281
|
+
"""
|
282
|
+
# from langchain_community.vectorstores.redis.constants import REDIS_TAG_SEPARATOR
|
283
|
+
|
284
|
+
_default_date = datetime.now()
|
285
|
+
|
286
|
+
metadata = {
|
287
|
+
'text': text,
|
288
|
+
'category': kwargs.get('category', ''),
|
289
|
+
# NOTE: `created_at` must have a valid numeric value or results will
|
290
|
+
# not be returned for similarity_search for some reason
|
291
|
+
'created_at': kwargs.get('created_at') or int(_default_date.strftime('%y%m%d')),
|
292
|
+
'tags': '|'.join(tags) if tags else '',
|
293
|
+
'conversation': kwargs.get('conversation', ''),
|
294
|
+
'persona': kwargs.get('persona', ''),
|
295
|
+
}
|
296
|
+
|
297
|
+
ids = self.vector_store.add_texts([in_response_to or ''], [metadata])
|
298
|
+
|
299
|
+
metadata['created_at'] = _default_date
|
300
|
+
metadata['tags'] = tags or []
|
301
|
+
metadata.pop('text')
|
302
|
+
statement = StatementObject(
|
303
|
+
id=ids[0],
|
304
|
+
text=text,
|
305
|
+
**metadata
|
306
|
+
)
|
307
|
+
return statement
|
308
|
+
|
309
|
+
def create_many(self, statements):
|
310
|
+
"""
|
311
|
+
Creates multiple statement entries.
|
312
|
+
"""
|
313
|
+
Document = self.get_statement_model()
|
314
|
+
documents = [
|
315
|
+
Document(
|
316
|
+
page_content=statement.in_response_to or '',
|
317
|
+
metadata={
|
318
|
+
'text': statement.text,
|
319
|
+
'conversation': statement.conversation or '',
|
320
|
+
'created_at': int(statement.created_at.strftime('%y%m%d')),
|
321
|
+
'persona': statement.persona or '',
|
322
|
+
'tags': '|'.join(statement.tags) if statement.tags else '',
|
323
|
+
}
|
324
|
+
) for statement in statements
|
325
|
+
]
|
326
|
+
|
327
|
+
self.logger.info('Adding documents to the vector store')
|
328
|
+
|
329
|
+
self.vector_store.add_documents(documents)
|
330
|
+
|
331
|
+
def update(self, statement):
|
332
|
+
"""
|
333
|
+
Modifies an entry in the database.
|
334
|
+
Creates an entry if one does not exist.
|
335
|
+
"""
|
336
|
+
metadata = {
|
337
|
+
'text': statement.text,
|
338
|
+
'conversation': statement.conversation or '',
|
339
|
+
'created_at': int(statement.created_at.strftime('%y%m%d')),
|
340
|
+
'persona': statement.persona or '',
|
341
|
+
'tags': '|'.join(statement.tags) if statement.tags else '',
|
342
|
+
}
|
343
|
+
|
344
|
+
Document = self.get_statement_model()
|
345
|
+
document = Document(
|
346
|
+
page_content=statement.in_response_to or '',
|
347
|
+
metadata=metadata,
|
348
|
+
)
|
349
|
+
|
350
|
+
if statement.id:
|
351
|
+
self.vector_store.add_texts(
|
352
|
+
[document.page_content], [metadata], keys=[statement.id.split(':')[1]]
|
353
|
+
)
|
354
|
+
else:
|
355
|
+
self.vector_store.add_documents([document])
|
356
|
+
|
357
|
+
def get_random(self):
|
358
|
+
"""
|
359
|
+
Returns a random statement from the database.
|
360
|
+
"""
|
361
|
+
client = self.vector_store.index.client
|
362
|
+
|
363
|
+
random_key = client.randomkey()
|
364
|
+
|
365
|
+
if random_key:
|
366
|
+
random_id = random_key.decode().split(':')[1]
|
367
|
+
|
368
|
+
documents = self.vector_store.get_by_ids([random_id])
|
369
|
+
|
370
|
+
if documents:
|
371
|
+
return self.model_to_object(documents[0])
|
372
|
+
|
373
|
+
raise self.EmptyDatabaseException()
|
374
|
+
|
375
|
+
def drop(self):
|
376
|
+
"""
|
377
|
+
Remove all existing documents from the database.
|
378
|
+
"""
|
379
|
+
index_name = self.vector_store.config.index_name
|
380
|
+
client = self.vector_store.index.client
|
381
|
+
|
382
|
+
for key in client.scan_iter(f'{index_name}:*'):
|
383
|
+
# self.vector_store.index.drop_keys(key)
|
384
|
+
client.delete(key)
|
385
|
+
|
386
|
+
# Commenting this out for now because there is no step
|
387
|
+
# to recreate the index after it is dropped (really what
|
388
|
+
# we want is to delete all the keys in the index, but
|
389
|
+
# keep the index itself)
|
390
|
+
# self.vector_store.index.delete(drop=True)
|