ChatterBot 1.2.10__tar.gz → 1.2.11__tar.gz

This diff compares the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
Files changed (95)
  1. {chatterbot-1.2.10 → chatterbot-1.2.11}/ChatterBot.egg-info/PKG-INFO +1 -1
  2. {chatterbot-1.2.10 → chatterbot-1.2.11}/ChatterBot.egg-info/SOURCES.txt +2 -0
  3. {chatterbot-1.2.10 → chatterbot-1.2.11}/PKG-INFO +1 -1
  4. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/__init__.py +1 -1
  5. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/storage/sql_storage.py +153 -112
  6. chatterbot-1.2.11/tests/test_connection_pool.py +268 -0
  7. chatterbot-1.2.11/tests/test_poc_vulnerability.py +152 -0
  8. {chatterbot-1.2.10 → chatterbot-1.2.11}/ChatterBot.egg-info/dependency_links.txt +0 -0
  9. {chatterbot-1.2.10 → chatterbot-1.2.11}/ChatterBot.egg-info/requires.txt +0 -0
  10. {chatterbot-1.2.10 → chatterbot-1.2.11}/ChatterBot.egg-info/top_level.txt +0 -0
  11. {chatterbot-1.2.10 → chatterbot-1.2.11}/LICENSE +0 -0
  12. {chatterbot-1.2.10 → chatterbot-1.2.11}/README.md +0 -0
  13. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/__main__.py +0 -0
  14. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/adapters.py +0 -0
  15. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/chatterbot.py +0 -0
  16. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/comparisons.py +0 -0
  17. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/components.py +0 -0
  18. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/constants.py +0 -0
  19. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/conversation.py +0 -0
  20. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/corpus.py +0 -0
  21. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/exceptions.py +0 -0
  22. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/__init__.py +0 -0
  23. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/__init__.py +0 -0
  24. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/abstract_models.py +0 -0
  25. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/admin.py +0 -0
  26. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/apps.py +0 -0
  27. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/migrations/0001_initial.py +0 -0
  28. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/migrations/0002_statement_extra_data.py +0 -0
  29. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/migrations/0003_change_occurrence_default.py +0 -0
  30. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/migrations/0004_rename_in_response_to.py +0 -0
  31. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/migrations/0005_statement_created_at.py +0 -0
  32. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/migrations/0006_create_conversation.py +0 -0
  33. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/migrations/0007_response_created_at.py +0 -0
  34. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/migrations/0008_update_conversations.py +0 -0
  35. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/migrations/0009_tags.py +0 -0
  36. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/migrations/0010_statement_text.py +0 -0
  37. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/migrations/0011_blank_extra_data.py +0 -0
  38. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/migrations/0012_statement_created_at.py +0 -0
  39. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/migrations/0013_change_conversations.py +0 -0
  40. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/migrations/0014_remove_statement_extra_data.py +0 -0
  41. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/migrations/0015_statement_persona.py +0 -0
  42. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/migrations/0016_statement_stemmed_text.py +0 -0
  43. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/migrations/0017_tags_unique.py +0 -0
  44. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/migrations/0018_text_max_length.py +0 -0
  45. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/migrations/0019_alter_statement_id_alter_tag_id_and_more.py +0 -0
  46. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/migrations/0020_alter_statement_conversation_and_more.py +0 -0
  47. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/migrations/__init__.py +0 -0
  48. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/model_admin.py +0 -0
  49. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/models.py +0 -0
  50. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/django_chatterbot/settings.py +0 -0
  51. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/sqlalchemy_app/__init__.py +0 -0
  52. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/ext/sqlalchemy_app/models.py +0 -0
  53. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/filters.py +0 -0
  54. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/languages.py +0 -0
  55. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/llm.py +0 -0
  56. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/logic/__init__.py +0 -0
  57. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/logic/best_match.py +0 -0
  58. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/logic/logic_adapter.py +0 -0
  59. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/logic/mathematical_evaluation.py +0 -0
  60. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/logic/specific_response.py +0 -0
  61. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/logic/time_adapter.py +0 -0
  62. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/logic/unit_conversion.py +0 -0
  63. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/parsing.py +0 -0
  64. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/preprocessors.py +0 -0
  65. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/response_selection.py +0 -0
  66. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/search.py +0 -0
  67. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/storage/__init__.py +0 -0
  68. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/storage/django_storage.py +0 -0
  69. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/storage/mongodb.py +0 -0
  70. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/storage/redis.py +0 -0
  71. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/storage/storage_adapter.py +0 -0
  72. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/tagging.py +0 -0
  73. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/trainers.py +0 -0
  74. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/utils.py +0 -0
  75. {chatterbot-1.2.10 → chatterbot-1.2.11}/chatterbot/vectorstores.py +0 -0
  76. {chatterbot-1.2.10 → chatterbot-1.2.11}/pyproject.toml +0 -0
  77. {chatterbot-1.2.10 → chatterbot-1.2.11}/setup.cfg +0 -0
  78. {chatterbot-1.2.10 → chatterbot-1.2.11}/tests/test_adapter_validation.py +0 -0
  79. {chatterbot-1.2.10 → chatterbot-1.2.11}/tests/test_benchmarks.py +0 -0
  80. {chatterbot-1.2.10 → chatterbot-1.2.11}/tests/test_chatbot.py +0 -0
  81. {chatterbot-1.2.10 → chatterbot-1.2.11}/tests/test_cli.py +0 -0
  82. {chatterbot-1.2.10 → chatterbot-1.2.11}/tests/test_comparisons.py +0 -0
  83. {chatterbot-1.2.10 → chatterbot-1.2.11}/tests/test_conversations.py +0 -0
  84. {chatterbot-1.2.10 → chatterbot-1.2.11}/tests/test_corpus.py +0 -0
  85. {chatterbot-1.2.10 → chatterbot-1.2.11}/tests/test_examples.py +0 -0
  86. {chatterbot-1.2.10 → chatterbot-1.2.11}/tests/test_filters.py +0 -0
  87. {chatterbot-1.2.10 → chatterbot-1.2.11}/tests/test_initialization.py +0 -0
  88. {chatterbot-1.2.10 → chatterbot-1.2.11}/tests/test_languages.py +0 -0
  89. {chatterbot-1.2.10 → chatterbot-1.2.11}/tests/test_parsing.py +0 -0
  90. {chatterbot-1.2.10 → chatterbot-1.2.11}/tests/test_preprocessors.py +0 -0
  91. {chatterbot-1.2.10 → chatterbot-1.2.11}/tests/test_response_selection.py +0 -0
  92. {chatterbot-1.2.10 → chatterbot-1.2.11}/tests/test_search.py +0 -0
  93. {chatterbot-1.2.10 → chatterbot-1.2.11}/tests/test_tagging.py +0 -0
  94. {chatterbot-1.2.10 → chatterbot-1.2.11}/tests/test_turing.py +0 -0
  95. {chatterbot-1.2.10 → chatterbot-1.2.11}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ChatterBot
3
- Version: 1.2.10
3
+ Version: 1.2.11
4
4
  Summary: ChatterBot is a machine learning, conversational dialog engine
5
5
  Author: Gunther Cox
6
6
  License-Expression: BSD-3-Clause
@@ -77,6 +77,7 @@ tests/test_benchmarks.py
77
77
  tests/test_chatbot.py
78
78
  tests/test_cli.py
79
79
  tests/test_comparisons.py
80
+ tests/test_connection_pool.py
80
81
  tests/test_conversations.py
81
82
  tests/test_corpus.py
82
83
  tests/test_examples.py
@@ -84,6 +85,7 @@ tests/test_filters.py
84
85
  tests/test_initialization.py
85
86
  tests/test_languages.py
86
87
  tests/test_parsing.py
88
+ tests/test_poc_vulnerability.py
87
89
  tests/test_preprocessors.py
88
90
  tests/test_response_selection.py
89
91
  tests/test_search.py
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ChatterBot
3
- Version: 1.2.10
3
+ Version: 1.2.11
4
4
  Summary: ChatterBot is a machine learning, conversational dialog engine
5
5
  Author: Gunther Cox
6
6
  License-Expression: BSD-3-Clause
@@ -4,7 +4,7 @@ ChatterBot is a machine learning, conversational dialog engine.
4
4
  from .chatterbot import ChatBot
5
5
 
6
6
 
7
- __version__ = '1.2.10'
7
+ __version__ = '1.2.11'
8
8
 
9
9
  __all__ = (
10
10
  'ChatBot',
@@ -23,7 +23,7 @@ class SQLStorageAdapter(StorageAdapter):
23
23
  from sqlalchemy import create_engine, inspect, event
24
24
  from sqlalchemy import Index
25
25
  from sqlalchemy.engine import Engine
26
- from sqlalchemy.orm import sessionmaker
26
+ from sqlalchemy.orm import sessionmaker, scoped_session
27
27
 
28
28
  self.database_uri = kwargs.get('database_uri', False)
29
29
 
@@ -35,7 +35,10 @@ class SQLStorageAdapter(StorageAdapter):
35
35
  if not self.database_uri:
36
36
  self.database_uri = 'sqlite:///db.sqlite3'
37
37
 
38
- self.engine = create_engine(self.database_uri)
38
+ # Configure connection pool with safe defaults to prevent exhaustion
39
+ # Note: SQLite uses SingletonThreadPool which doesn't support these params
40
+ # PostgreSQL, MySQL, etc. use QueuePool which does support them
41
+ pool_config = {}
39
42
 
40
43
  if self.database_uri.startswith('sqlite://'):
41
44
 
@@ -66,6 +69,23 @@ class SQLStorageAdapter(StorageAdapter):
66
69
  cursor.execute('PRAGMA synchronous=NORMAL')
67
70
  cursor.close()
68
71
 
72
+ else:
73
+ # Only apply pool configuration for databases that support QueuePool
74
+ # pool_size: Maximum persistent connections (10)
75
+ # max_overflow: Additional connections during peak load (20)
76
+ # pool_timeout: Seconds to wait for connection before error (30)
77
+ # pool_recycle: Recycle connections after 1 hour to prevent stale connections
78
+ # pool_pre_ping: Test connections before using to detect disconnects
79
+ pool_config = {
80
+ 'pool_size': kwargs.get('pool_size', 10),
81
+ 'max_overflow': kwargs.get('max_overflow', 20),
82
+ 'pool_timeout': kwargs.get('pool_timeout', 30),
83
+ 'pool_recycle': kwargs.get('pool_recycle', 3600),
84
+ 'pool_pre_ping': kwargs.get('pool_pre_ping', True),
85
+ }
86
+
87
+ self.engine = create_engine(self.database_uri, **pool_config)
88
+
69
89
  if not inspect(self.engine).has_table('statement'):
70
90
  self.create_database()
71
91
 
@@ -91,7 +111,10 @@ class SQLStorageAdapter(StorageAdapter):
91
111
 
92
112
  search_in_response_to_index.create(bind=self.engine)
93
113
 
94
- self.Session = sessionmaker(bind=self.engine, expire_on_commit=True)
114
+ # Use a scoped session for thread-safe session management
115
+ # This provides thread-local session storage to prevent session sharing across threads
116
+ session_factory = sessionmaker(bind=self.engine, expire_on_commit=True)
117
+ self.Session = scoped_session(session_factory)
95
118
 
96
119
  def get_statement_model(self):
97
120
  """
@@ -119,9 +142,11 @@ class SQLStorageAdapter(StorageAdapter):
119
142
  Statement = self.get_model('statement')
120
143
 
121
144
  session = self.Session()
122
- statement_count = session.query(Statement).count()
123
- session.close()
124
- return statement_count
145
+ try:
146
+ statement_count = session.query(Statement).count()
147
+ return statement_count
148
+ finally:
149
+ session.close()
125
150
 
126
151
  def remove(self, statement_text):
127
152
  """
@@ -131,13 +156,14 @@ class SQLStorageAdapter(StorageAdapter):
131
156
  """
132
157
  Statement = self.get_model('statement')
133
158
  session = self.Session()
159
+ try:
160
+ query = session.query(Statement).filter_by(text=statement_text)
161
+ record = query.first()
134
162
 
135
- query = session.query(Statement).filter_by(text=statement_text)
136
- record = query.first()
137
-
138
- session.delete(record)
139
- session.commit()
140
- session.close()
163
+ session.delete(record)
164
+ session.commit()
165
+ finally:
166
+ session.close()
141
167
 
142
168
  def filter(self, **kwargs):
143
169
  """
@@ -152,8 +178,6 @@ class SQLStorageAdapter(StorageAdapter):
152
178
  Statement = self.get_model('statement')
153
179
  Tag = self.get_model('tag')
154
180
 
155
- session = self.Session()
156
-
157
181
  page_size = kwargs.pop('page_size', 1000)
158
182
  order_by = kwargs.pop('order_by', None)
159
183
  tags = kwargs.pop('tags', [])
@@ -167,65 +191,69 @@ class SQLStorageAdapter(StorageAdapter):
167
191
  if isinstance(tags, str):
168
192
  tags = [tags]
169
193
 
170
- if len(kwargs) == 0:
171
- statements = session.query(Statement).filter()
172
- else:
173
- statements = session.query(Statement).filter_by(**kwargs)
174
-
175
- if tags:
176
- statements = statements.join(Statement.tags).filter(
177
- Tag.name.in_(tags)
178
- )
179
-
180
- if exclude_text:
181
- statements = statements.filter(
182
- ~Statement.text.in_(exclude_text)
183
- )
194
+ # Use context manager to ensure session cleanup even if generator is partially consumed
195
+ session = self.Session()
196
+ try:
197
+ if len(kwargs) == 0:
198
+ statements = session.query(Statement).filter()
199
+ else:
200
+ statements = session.query(Statement).filter_by(**kwargs)
201
+
202
+ if tags:
203
+ statements = statements.join(Statement.tags).filter(
204
+ Tag.name.in_(tags)
205
+ )
184
206
 
185
- if exclude_text_words:
186
- or_word_query = [
187
- Statement.text.ilike('%' + word + '%') for word in exclude_text_words
188
- ]
189
- statements = statements.filter(
190
- ~or_(*or_word_query)
191
- )
207
+ if exclude_text:
208
+ statements = statements.filter(
209
+ ~Statement.text.in_(exclude_text)
210
+ )
192
211
 
193
- if persona_not_startswith:
194
- statements = statements.filter(
195
- ~Statement.persona.startswith('bot:')
196
- )
212
+ if exclude_text_words:
213
+ or_word_query = [
214
+ Statement.text.ilike('%' + word + '%') for word in exclude_text_words
215
+ ]
216
+ statements = statements.filter(
217
+ ~or_(*or_word_query)
218
+ )
197
219
 
198
- if search_text_contains:
199
- or_query = [
200
- Statement.search_text.contains(word) for word in search_text_contains.split(' ')
201
- ]
202
- statements = statements.filter(
203
- or_(*or_query)
204
- )
220
+ if persona_not_startswith:
221
+ statements = statements.filter(
222
+ ~Statement.persona.startswith('bot:')
223
+ )
205
224
 
206
- if search_in_response_to_contains:
207
- or_query = [
208
- Statement.search_in_response_to.contains(word) for word in search_in_response_to_contains.split(' ')
209
- ]
210
- statements = statements.filter(
211
- or_(*or_query)
212
- )
225
+ if search_text_contains:
226
+ or_query = [
227
+ Statement.search_text.contains(word) for word in search_text_contains.split(' ')
228
+ ]
229
+ statements = statements.filter(
230
+ or_(*or_query)
231
+ )
213
232
 
214
- if order_by:
233
+ if search_in_response_to_contains:
234
+ or_query = [
235
+ Statement.search_in_response_to.contains(word) for word in search_in_response_to_contains.split(' ')
236
+ ]
237
+ statements = statements.filter(
238
+ or_(*or_query)
239
+ )
215
240
 
216
- if 'created_at' in order_by:
217
- index = order_by.index('created_at')
218
- order_by[index] = Statement.created_at.asc()
241
+ if order_by:
219
242
 
220
- statements = statements.order_by(*order_by)
243
+ if 'created_at' in order_by:
244
+ index = order_by.index('created_at')
245
+ order_by[index] = Statement.created_at.asc()
221
246
 
222
- total_statements = statements.count()
247
+ statements = statements.order_by(*order_by)
223
248
 
224
- for start_index in range(0, total_statements, page_size):
225
- for statement in statements.slice(start_index, start_index + page_size):
226
- yield self.model_to_object(statement)
249
+ total_statements = statements.count()
227
250
 
228
- session.close()
251
+ for start_index in range(0, total_statements, page_size):
252
+ for statement in statements.slice(start_index, start_index + page_size):
253
+ yield self.model_to_object(statement)
254
+ finally:
255
+ # Always close session, even if generator is abandoned or exception occurs
256
+ session.close()
229
257
 
230
258
  def create(
231
259
  self,
@@ -336,8 +364,11 @@ class SQLStorageAdapter(StorageAdapter):
336
364
  statement_model_object.tags.append(tag)
337
365
  create_statements.append(statement_model_object)
338
366
 
339
- session.add_all(create_statements)
340
- session.commit()
367
+ try:
368
+ session.add_all(create_statements)
369
+ session.commit()
370
+ finally:
371
+ session.close()
341
372
 
342
373
  def update(self, statement):
343
374
  """
@@ -348,49 +379,51 @@ class SQLStorageAdapter(StorageAdapter):
348
379
  Tag = self.get_model('tag')
349
380
 
350
381
  session = self.Session()
351
- record = None
382
+ try:
383
+ record = None
352
384
 
353
- if hasattr(statement, 'id') and statement.id is not None:
354
- record = session.get(Statement, statement.id)
355
- else:
356
- record = session.query(Statement).filter(
357
- Statement.text == statement.text,
358
- Statement.conversation == statement.conversation,
359
- ).first()
360
-
361
- # Create a new statement entry if one does not already exist
362
- if not record:
363
- record = Statement(
364
- text=statement.text,
365
- conversation=statement.conversation,
366
- persona=statement.persona
367
- )
385
+ if hasattr(statement, 'id') and statement.id is not None:
386
+ record = session.get(Statement, statement.id)
387
+ else:
388
+ record = session.query(Statement).filter(
389
+ Statement.text == statement.text,
390
+ Statement.conversation == statement.conversation,
391
+ ).first()
368
392
 
369
- # Update the response value
370
- record.in_response_to = statement.in_response_to
393
+ # Create a new statement entry if one does not already exist
394
+ if not record:
395
+ record = Statement(
396
+ text=statement.text,
397
+ conversation=statement.conversation,
398
+ persona=statement.persona
399
+ )
371
400
 
372
- record.created_at = statement.created_at
401
+ # Update the response value
402
+ record.in_response_to = statement.in_response_to
373
403
 
374
- if not statement.search_text:
375
- if self.raise_on_missing_search_text:
376
- raise Exception('update issued without search_text value')
404
+ record.created_at = statement.created_at
377
405
 
378
- if statement.in_response_to and not statement.search_in_response_to:
379
- if self.raise_on_missing_search_text:
380
- raise Exception('update issued without search_in_response_to value')
406
+ if not statement.search_text:
407
+ if self.raise_on_missing_search_text:
408
+ raise Exception('update issued without search_text value')
381
409
 
382
- for tag_name in statement.get_tags():
383
- tag = session.query(Tag).filter_by(name=tag_name).first()
410
+ if statement.in_response_to and not statement.search_in_response_to:
411
+ if self.raise_on_missing_search_text:
412
+ raise Exception('update issued without search_in_response_to value')
384
413
 
385
- if not tag:
386
- # Create the record
387
- tag = Tag(name=tag_name)
414
+ for tag_name in statement.get_tags():
415
+ tag = session.query(Tag).filter_by(name=tag_name).first()
388
416
 
389
- record.tags.append(tag)
417
+ if not tag:
418
+ # Create the record
419
+ tag = Tag(name=tag_name)
390
420
 
391
- session.add(record)
392
- session.commit()
393
- session.close()
421
+ record.tags.append(tag)
422
+
423
+ session.add(record)
424
+ session.commit()
425
+ finally:
426
+ session.close()
394
427
 
395
428
  def get_random(self):
396
429
  """
@@ -399,17 +432,19 @@ class SQLStorageAdapter(StorageAdapter):
399
432
  Statement = self.get_model('statement')
400
433
 
401
434
  session = self.Session()
402
- count = self.count()
403
- if count < 1:
404
- raise self.EmptyDatabaseException()
435
+ try:
436
+ count = self.count()
437
+ if count < 1:
438
+ raise self.EmptyDatabaseException()
405
439
 
406
- random_index = random.randrange(0, count)
407
- random_statement = session.query(Statement)[random_index]
440
+ random_index = random.randrange(0, count)
441
+ random_statement = session.query(Statement)[random_index]
408
442
 
409
- statement = self.model_to_object(random_statement)
443
+ statement = self.model_to_object(random_statement)
410
444
 
411
- session.close()
412
- return statement
445
+ return statement
446
+ finally:
447
+ session.close()
413
448
 
414
449
  def drop(self):
415
450
  """
@@ -419,12 +454,13 @@ class SQLStorageAdapter(StorageAdapter):
419
454
  Tag = self.get_model('tag')
420
455
 
421
456
  session = self.Session()
457
+ try:
458
+ session.query(Statement).delete()
459
+ session.query(Tag).delete()
422
460
 
423
- session.query(Statement).delete()
424
- session.query(Tag).delete()
425
-
426
- session.commit()
427
- session.close()
461
+ session.commit()
462
+ finally:
463
+ session.close()
428
464
 
429
465
  def create_database(self):
430
466
  """
@@ -438,5 +474,10 @@ class SQLStorageAdapter(StorageAdapter):
438
474
  Close the database connection and dispose of the engine.
439
475
  This ensures proper cleanup of resources.
440
476
  """
477
+ # Remove thread-local sessions from scoped_session registry
478
+ if hasattr(self, 'Session'):
479
+ self.Session.remove()
480
+
481
+ # Dispose of the connection pool
441
482
  if hasattr(self, 'engine'):
442
483
  self.engine.dispose()
@@ -0,0 +1,268 @@
1
+ """
2
+ Tests for database connection pool management and concurrency.
3
+
4
+ These tests verify that the fixes for the connection pool exhaustion
5
+ vulnerability (CVE-TBD) are working correctly.
6
+ """
7
+ import threading
8
+ import time
9
+ import unittest
10
+ from chatterbot import ChatBot
11
+ from chatterbot.trainers import ListTrainer
12
+
13
+
14
+ class ConnectionPoolTestCase(unittest.TestCase):
15
+ """
16
+ Test cases for database connection pool management.
17
+ """
18
+
19
+ def setUp(self):
20
+ """
21
+ Set up test fixtures before each test.
22
+ """
23
+ # Use in-memory SQLite for fast testing
24
+ # Note: SQLite doesn't use QueuePool, so pool params are ignored
25
+ self.chatbot = ChatBot(
26
+ 'TestBot',
27
+ database_uri='sqlite://',
28
+ )
29
+
30
+ # Train with some basic responses
31
+ trainer = ListTrainer(self.chatbot)
32
+ trainer.train([
33
+ 'Hi',
34
+ 'Hello!',
35
+ 'How are you?',
36
+ 'I am doing well.',
37
+ 'What is your name?',
38
+ 'My name is TestBot.',
39
+ ])
40
+
41
+ def tearDown(self):
42
+ """
43
+ Clean up after each test.
44
+ """
45
+ self.chatbot.storage.drop()
46
+ self.chatbot.storage.close()
47
+
48
+ def test_concurrent_requests_no_exhaustion(self):
49
+ """
50
+ Test that concurrent requests don't exhaust the connection pool.
51
+
52
+ This was the original vulnerability - concurrent get_response() calls
53
+ would leak sessions and exhaust the pool.
54
+ """
55
+ num_threads = 30 # More than pool_size + max_overflow
56
+ responses = []
57
+ errors = []
58
+
59
+ def make_request():
60
+ try:
61
+ response = self.chatbot.get_response('Hi')
62
+ responses.append(str(response))
63
+ except Exception as e:
64
+ errors.append(e)
65
+
66
+ threads = []
67
+ for _ in range(num_threads):
68
+ t = threading.Thread(target=make_request)
69
+ threads.append(t)
70
+ t.start()
71
+
72
+ # Wait for all threads to complete
73
+ for t in threads:
74
+ t.join(timeout=10)
75
+
76
+ # Verify no errors occurred
77
+ self.assertEqual(len(errors), 0,
78
+ f"Connection pool exhaustion occurred: {errors}")
79
+
80
+ # Verify all threads got responses
81
+ self.assertEqual(len(responses), num_threads,
82
+ "Not all threads received responses")
83
+
84
+ def test_rapid_sequential_requests(self):
85
+ """
86
+ Test that rapid sequential requests properly release connections.
87
+ """
88
+ num_requests = 50 # More than pool size
89
+
90
+ for i in range(num_requests):
91
+ response = self.chatbot.get_response(f'Request {i}')
92
+ self.assertIsNotNone(response)
93
+
94
+ def test_partial_filter_consumption(self):
95
+ """
96
+ Test that partially consuming filter() results doesn't leak sessions.
97
+
98
+ This was a key part of the vulnerability - the filter() generator
99
+ would not close the session if iteration stopped early.
100
+ """
101
+ # Create many statements
102
+ trainer = ListTrainer(self.chatbot)
103
+ for i in range(100):
104
+ trainer.train([f'Question {i}', f'Answer {i}'])
105
+
106
+ # Partially consume filter results many times
107
+ for _ in range(50):
108
+ results = self.chatbot.storage.filter()
109
+ # Only consume first result
110
+ first = next(results, None)
111
+ self.assertIsNotNone(first)
112
+ # Don't consume the rest - this should still clean up the session
113
+
114
+ # If sessions weren't cleaned up, this would fail
115
+ response = self.chatbot.get_response('Hi')
116
+ self.assertIsNotNone(response)
117
+
118
+ def test_concurrent_training(self):
119
+ """
120
+ Test that concurrent training operations don't leak connections.
121
+ """
122
+ errors = []
123
+
124
+ def train_batch(batch_id):
125
+ try:
126
+ trainer = ListTrainer(self.chatbot)
127
+ trainer.train([
128
+ f'Training question {batch_id}',
129
+ f'Training answer {batch_id}',
130
+ ])
131
+ except Exception as e:
132
+ errors.append(e)
133
+
134
+ threads = []
135
+ for i in range(20):
136
+ t = threading.Thread(target=train_batch, args=(i,))
137
+ threads.append(t)
138
+ t.start()
139
+
140
+ for t in threads:
141
+ t.join(timeout=10)
142
+
143
+ self.assertEqual(len(errors), 0,
144
+ f"Errors during concurrent training: {errors}")
145
+
146
+ def test_session_cleanup_on_exception(self):
147
+ """
148
+ Test that sessions are cleaned up even when exceptions occur.
149
+ """
150
+ # Force an error during a database operation
151
+ try:
152
+ # Create a statement with invalid data
153
+ self.chatbot.storage.create(
154
+ text='', # Empty text might cause issues
155
+ in_response_to=None
156
+ )
157
+ except Exception:
158
+ pass # Expected to fail
159
+
160
+ # Verify the pool is still usable
161
+ response = self.chatbot.get_response('Hi')
162
+ self.assertIsNotNone(response)
163
+
164
+ def test_scoped_session_thread_safety(self):
165
+ """
166
+ Test that scoped_session provides proper thread isolation.
167
+ """
168
+ results = {}
169
+
170
+ def check_session_isolation(thread_id):
171
+ # Each thread should get its own session
172
+ session1 = self.chatbot.storage.Session()
173
+ time.sleep(0.01) # Small delay to encourage thread interleaving
174
+ session2 = self.chatbot.storage.Session()
175
+
176
+ # In the same thread, scoped_session should return the same session
177
+ results[thread_id] = (id(session1) == id(session2))
178
+
179
+ session1.close()
180
+ # After close, scoped_session should return the same instance
181
+ # (it doesn't create a new one, just reuses the thread-local one)
182
+
183
+ threads = []
184
+ for i in range(5):
185
+ t = threading.Thread(target=check_session_isolation, args=(i,))
186
+ threads.append(t)
187
+ t.start()
188
+
189
+ for t in threads:
190
+ t.join()
191
+
192
+ # All threads should have gotten consistent session behavior
193
+ for thread_id, same_session in results.items():
194
+ self.assertTrue(same_session,
195
+ f"Thread {thread_id} got different sessions")
196
+
197
+
198
+ class ConnectionPoolConfigTestCase(unittest.TestCase):
199
+ """
200
+ Test cases for connection pool configuration options.
201
+ Note: These tests are skipped for SQLite since it uses SingletonThreadPool.
202
+ """
203
+
204
+ def test_pool_config_not_applied_to_sqlite(self):
205
+ """
206
+ Test that pool config is not applied to SQLite (uses SingletonThreadPool).
207
+ """
208
+ chatbot = ChatBot(
209
+ 'SQLiteBot',
210
+ database_uri='sqlite://',
211
+ pool_size=3, # Should be ignored
212
+ max_overflow=2, # Should be ignored
213
+ )
214
+
215
+ # SQLite uses SingletonThreadPool, not QueuePool
216
+ from sqlalchemy.pool import SingletonThreadPool
217
+ self.assertIsInstance(chatbot.storage.engine.pool, SingletonThreadPool)
218
+
219
+ chatbot.storage.close()
220
+
221
+ @unittest.skip("Requires PostgreSQL/MySQL database for testing")
222
+ def test_custom_pool_size_postgres(self):
223
+ """
224
+ Test that custom pool_size is respected for PostgreSQL.
225
+ """
226
+ # This test would require a PostgreSQL connection
227
+ # chatbot = ChatBot(
228
+ # 'ConfigBot',
229
+ # database_uri='postgresql://user:pass@localhost/test',
230
+ # pool_size=3,
231
+ # max_overflow=2,
232
+ # )
233
+ # self.assertEqual(chatbot.storage.engine.pool.size(), 3)
234
+ # chatbot.storage.close()
235
+ pass
236
+
237
+ @unittest.skip("Requires PostgreSQL/MySQL database for testing")
238
+ def test_default_pool_config_postgres(self):
239
+ """
240
+ Test that default pool configuration is applied for PostgreSQL.
241
+ """
242
+ # This test would require a PostgreSQL connection
243
+ # chatbot = ChatBot(
244
+ # 'DefaultBot',
245
+ # database_uri='postgresql://user:pass@localhost/test',
246
+ # )
247
+ # pool = chatbot.storage.engine.pool
248
+ # self.assertEqual(pool.size(), 10) # Default pool_size
249
+ # chatbot.storage.close()
250
+ pass
251
+
252
+ @unittest.skip("Requires PostgreSQL/MySQL database for testing")
253
+ def test_pool_pre_ping_enabled_postgres(self):
254
+ """
255
+ Test that pool_pre_ping is enabled by default for PostgreSQL.
256
+ """
257
+ # This test would require a PostgreSQL connection
258
+ # chatbot = ChatBot(
259
+ # 'PingBot',
260
+ # database_uri='postgresql://user:pass@localhost/test',
261
+ # )
262
+ # self.assertTrue(chatbot.storage.engine.pool._pre_ping)
263
+ # chatbot.storage.close()
264
+ pass
265
+
266
+
267
+ if __name__ == '__main__':
268
+ unittest.main()
@@ -0,0 +1,152 @@
1
+ """
2
+ Proof-of-Concept test for the connection pool exhaustion vulnerability (CVE-TBD).
3
+
4
+ This test demonstrates that the original vulnerability has been fixed.
5
+ The PoC from the security report would cause timeout errors before the fix.
6
+ """
7
+ import threading
8
+ import tempfile
9
+ import os
10
+ import unittest
11
+ from chatterbot import ChatBot
12
+ from chatterbot.trainers import ListTrainer
13
+
14
+
15
+ class PoC_VulnerabilityTestCase(unittest.TestCase):
16
+ """
17
+ Test case demonstrating the fix for the connection pool exhaustion vulnerability.
18
+
19
+ This uses a file-based SQLite database which doesn't have the same
20
+ thread-restrictions as in-memory databases.
21
+ """
22
+
23
+ def setUp(self):
24
+ """
25
+ Set up test fixtures.
26
+ """
27
+ # Create a temporary database file
28
+ self.db_fd, self.db_path = tempfile.mkstemp(suffix='.sqlite3')
29
+
30
+ self.chatbot = ChatBot(
31
+ 'TestBot',
32
+ database_uri=f'sqlite:///{self.db_path}',
33
+ )
34
+
35
+ # Train with basic data
36
+ trainer = ListTrainer(self.chatbot)
37
+ trainer.train(['hello', 'hi there'])
38
+
39
+ def tearDown(self):
40
+ """
41
+ Clean up after test.
42
+ """
43
+ self.chatbot.storage.close()
44
+ os.close(self.db_fd)
45
+ os.unlink(self.db_path)
46
+
47
+ def test_original_poc_no_longer_causes_timeout(self):
48
+ """
49
+ Test that the original PoC from the security report no longer causes errors.
50
+
51
+ Before the fix: This would cause SQLAlchemy TimeoutError due to pool exhaustion
52
+ After the fix: All requests complete successfully
53
+ """
54
+ def attack():
55
+ try:
56
+ response = self.chatbot.get_response("hello")
57
+ results.append(('success', str(response)))
58
+ except Exception as e:
59
+ results.append(('error', str(e)))
60
+
61
+ results = []
62
+ threads = []
63
+
64
+ # Original PoC used 30 threads
65
+ for _ in range(30):
66
+ t = threading.Thread(target=attack)
67
+ t.start()
68
+ threads.append(t)
69
+
70
+ for t in threads:
71
+ t.join(timeout=15) # Should complete well before timeout
72
+
73
+ # Count successes and errors
74
+ successes = [r for r in results if r[0] == 'success']
75
+ errors = [r for r in results if r[0] == 'error']
76
+
77
+ # Before fix: Would have many TimeoutError exceptions
78
+ # After fix: All should succeed
79
+ self.assertEqual(len(errors), 0,
80
+ f"Got {len(errors)} errors (expected 0). Errors: {[e[1] for e in errors][:5]}")
81
+ self.assertEqual(len(successes), 30,
82
+ f"Got {len(successes)} successes (expected 30)")
83
+
84
+ def test_high_concurrency_sustained(self):
85
+ """
86
+ Test sustained high concurrency doesn't cause issues.
87
+ """
88
+ request_count = 0
89
+ lock = threading.Lock()
90
+
91
+ def make_many_requests():
92
+ nonlocal request_count
93
+ for _ in range(10):
94
+ try:
95
+ self.chatbot.get_response("hello")
96
+ with lock:
97
+ request_count += 1
98
+ except Exception:
99
+ pass
100
+
101
+ threads = []
102
+ for _ in range(10): # 10 threads × 10 requests = 100 total
103
+ t = threading.Thread(target=make_many_requests)
104
+ t.start()
105
+ threads.append(t)
106
+
107
+ for t in threads:
108
+ t.join(timeout=30)
109
+
110
+ # Should have completed all 100 requests
111
+ self.assertGreater(request_count, 90, # Allow for some timing issues
112
+ f"Only completed {request_count}/100 requests")
113
+
114
+
115
+ class SequentialPerformanceTestCase(unittest.TestCase):
116
+ """
117
+ Test that the fixes don't negatively impact single-threaded performance.
118
+ """
119
+
120
+ def setUp(self):
121
+ """
122
+ Set up test fixtures.
123
+ """
124
+ self.db_fd, self.db_path = tempfile.mkstemp(suffix='.sqlite3')
125
+
126
+ self.chatbot = ChatBot(
127
+ 'PerfBot',
128
+ database_uri=f'sqlite:///{self.db_path}',
129
+ )
130
+
131
+ trainer = ListTrainer(self.chatbot)
132
+ trainer.train(['hello', 'hi', 'how are you', 'good'])
133
+
134
+ def tearDown(self):
135
+ """
136
+ Clean up.
137
+ """
138
+ self.chatbot.storage.close()
139
+ os.close(self.db_fd)
140
+ os.unlink(self.db_path)
141
+
142
+ def test_sequential_requests_still_work(self):
143
+ """
144
+ Test that normal sequential usage still works correctly.
145
+ """
146
+ for i in range(50):
147
+ response = self.chatbot.get_response(f"message {i}")
148
+ self.assertIsNotNone(response)
149
+
150
+
151
+ if __name__ == '__main__':
152
+ unittest.main()
File without changes
File without changes
File without changes
File without changes