slidge 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96):
  1. slidge/__init__.py +61 -0
  2. slidge/__main__.py +192 -0
  3. slidge/command/__init__.py +28 -0
  4. slidge/command/adhoc.py +258 -0
  5. slidge/command/admin.py +193 -0
  6. slidge/command/base.py +441 -0
  7. slidge/command/categories.py +3 -0
  8. slidge/command/chat_command.py +288 -0
  9. slidge/command/register.py +179 -0
  10. slidge/command/user.py +250 -0
  11. slidge/contact/__init__.py +8 -0
  12. slidge/contact/contact.py +452 -0
  13. slidge/contact/roster.py +192 -0
  14. slidge/core/__init__.py +3 -0
  15. slidge/core/cache.py +183 -0
  16. slidge/core/config.py +209 -0
  17. slidge/core/gateway/__init__.py +3 -0
  18. slidge/core/gateway/base.py +892 -0
  19. slidge/core/gateway/caps.py +63 -0
  20. slidge/core/gateway/delivery_receipt.py +52 -0
  21. slidge/core/gateway/disco.py +80 -0
  22. slidge/core/gateway/mam.py +75 -0
  23. slidge/core/gateway/muc_admin.py +35 -0
  24. slidge/core/gateway/ping.py +66 -0
  25. slidge/core/gateway/presence.py +95 -0
  26. slidge/core/gateway/registration.py +53 -0
  27. slidge/core/gateway/search.py +102 -0
  28. slidge/core/gateway/session_dispatcher.py +757 -0
  29. slidge/core/gateway/vcard_temp.py +130 -0
  30. slidge/core/mixins/__init__.py +19 -0
  31. slidge/core/mixins/attachment.py +506 -0
  32. slidge/core/mixins/avatar.py +167 -0
  33. slidge/core/mixins/base.py +31 -0
  34. slidge/core/mixins/disco.py +130 -0
  35. slidge/core/mixins/lock.py +31 -0
  36. slidge/core/mixins/message.py +398 -0
  37. slidge/core/mixins/message_maker.py +154 -0
  38. slidge/core/mixins/presence.py +217 -0
  39. slidge/core/mixins/recipient.py +43 -0
  40. slidge/core/pubsub.py +525 -0
  41. slidge/core/session.py +752 -0
  42. slidge/group/__init__.py +10 -0
  43. slidge/group/archive.py +125 -0
  44. slidge/group/bookmarks.py +163 -0
  45. slidge/group/participant.py +440 -0
  46. slidge/group/room.py +1095 -0
  47. slidge/migration.py +18 -0
  48. slidge/py.typed +0 -0
  49. slidge/slixfix/__init__.py +68 -0
  50. slidge/slixfix/link_preview/__init__.py +10 -0
  51. slidge/slixfix/link_preview/link_preview.py +17 -0
  52. slidge/slixfix/link_preview/stanza.py +99 -0
  53. slidge/slixfix/roster.py +60 -0
  54. slidge/slixfix/xep_0077/__init__.py +10 -0
  55. slidge/slixfix/xep_0077/register.py +289 -0
  56. slidge/slixfix/xep_0077/stanza.py +104 -0
  57. slidge/slixfix/xep_0100/__init__.py +5 -0
  58. slidge/slixfix/xep_0100/gateway.py +121 -0
  59. slidge/slixfix/xep_0100/stanza.py +9 -0
  60. slidge/slixfix/xep_0153/__init__.py +10 -0
  61. slidge/slixfix/xep_0153/stanza.py +25 -0
  62. slidge/slixfix/xep_0153/vcard_avatar.py +23 -0
  63. slidge/slixfix/xep_0264/__init__.py +5 -0
  64. slidge/slixfix/xep_0264/stanza.py +36 -0
  65. slidge/slixfix/xep_0264/thumbnail.py +23 -0
  66. slidge/slixfix/xep_0292/__init__.py +5 -0
  67. slidge/slixfix/xep_0292/vcard4.py +100 -0
  68. slidge/slixfix/xep_0313/__init__.py +12 -0
  69. slidge/slixfix/xep_0313/mam.py +262 -0
  70. slidge/slixfix/xep_0313/stanza.py +359 -0
  71. slidge/slixfix/xep_0317/__init__.py +5 -0
  72. slidge/slixfix/xep_0317/hats.py +17 -0
  73. slidge/slixfix/xep_0317/stanza.py +28 -0
  74. slidge/slixfix/xep_0356_old/__init__.py +7 -0
  75. slidge/slixfix/xep_0356_old/privilege.py +167 -0
  76. slidge/slixfix/xep_0356_old/stanza.py +44 -0
  77. slidge/slixfix/xep_0424/__init__.py +9 -0
  78. slidge/slixfix/xep_0424/retraction.py +77 -0
  79. slidge/slixfix/xep_0424/stanza.py +28 -0
  80. slidge/slixfix/xep_0490/__init__.py +8 -0
  81. slidge/slixfix/xep_0490/mds.py +47 -0
  82. slidge/slixfix/xep_0490/stanza.py +17 -0
  83. slidge/util/__init__.py +15 -0
  84. slidge/util/archive_msg.py +61 -0
  85. slidge/util/conf.py +206 -0
  86. slidge/util/db.py +229 -0
  87. slidge/util/schema.sql +126 -0
  88. slidge/util/sql.py +508 -0
  89. slidge/util/test.py +295 -0
  90. slidge/util/types.py +180 -0
  91. slidge/util/util.py +295 -0
  92. slidge-0.1.0.dist-info/LICENSE +661 -0
  93. slidge-0.1.0.dist-info/METADATA +109 -0
  94. slidge-0.1.0.dist-info/RECORD +96 -0
  95. slidge-0.1.0.dist-info/WHEEL +4 -0
  96. slidge-0.1.0.dist-info/entry_points.txt +3 -0
slidge/util/schema.sql ADDED
@@ -0,0 +1,126 @@
1
-- Gateway users, keyed by their bare JID (UserMixin.user_store inserts
-- user.bare_jid). Most other tables point at user.id through user_id.
CREATE TABLE user(
    id INTEGER PRIMARY KEY,
    jid TEXT UNIQUE
);
5
+
6
-- Group chats (MUCs) registered for a given user. A MUC JID is only unique
-- per user: two users may each have a row for the same MUC JID.
CREATE TABLE muc(
    id INTEGER PRIMARY KEY,
    jid TEXT,
    user_id INTEGER,
    FOREIGN KEY(user_id) REFERENCES user(id),
    UNIQUE(user_id, jid)
);
13
+
14
-- Archived MUC messages, queried by MAMMixin.
-- sent_on: UNIX timestamp in seconds (MAMMixin stores msg.when.timestamp(),
--          a float, so values may be stored as REAL despite the INTEGER type).
-- xml: the serialized message stanza.
CREATE TABLE mam_message(
    id INTEGER PRIMARY KEY,
    message_id TEXT,
    sent_on INTEGER,
    sender_jid TEXT,
    xml TEXT,
    muc_id INTEGER,
    user_id INTEGER,
    FOREIGN KEY(muc_id) REFERENCES muc(id),
    FOREIGN KEY(user_id) REFERENCES user(id),
    -- message IDs only need to be unique within one MUC of one user
    UNIQUE(user_id, muc_id, message_id)
);
26
+
27
-- Archive queries filter and sort on sent_on.
CREATE INDEX mam_sent_on ON mam_message(sent_on);
-- MUC rows are looked up by JID in the scalar sub-queries of MAMMixin.
CREATE INDEX muc_jid ON muc(jid);
29
+
30
-- Mapping legacy_id <-> xmpp_id of messages, scoped to one gateway user
-- through user_id.
CREATE TABLE session_message_sent(
    id INTEGER PRIMARY KEY,
    legacy_id,
    xmpp_id TEXT,
    user_id INTEGER,
    FOREIGN KEY(user_id) REFERENCES user(id),
    -- Legacy IDs only have to be unique per user. A global UNIQUE made a
    -- REPLACE for one user silently delete another user row that happened
    -- to carry the same legacy ID.
    UNIQUE(user_id, legacy_id)
);

CREATE INDEX session_message_sent_legacy_id
    ON session_message_sent(legacy_id);
CREATE INDEX session_message_sent_xmpp_id
    ON session_message_sent(xmpp_id);
42
+
43
-- Mapping legacy_id <-> xmpp_id of MUC messages, scoped to one gateway user
-- through user_id.
CREATE TABLE session_message_sent_muc(
    id INTEGER PRIMARY KEY,
    legacy_id,
    xmpp_id TEXT,
    user_id INTEGER,
    FOREIGN KEY(user_id) REFERENCES user(id),
    -- Legacy IDs only have to be unique per user. A global UNIQUE made a
    -- REPLACE for one user silently delete another user row that happened
    -- to carry the same legacy ID.
    UNIQUE(user_id, legacy_id)
);

CREATE INDEX session_message_sent_muc_legacy_id
    ON session_message_sent_muc(legacy_id);
CREATE INDEX session_message_sent_muc_xmpp_id
    ON session_message_sent_muc(xmpp_id);
55
+
56
-- Mapping legacy_id <-> xmpp_id of MUC threads, scoped to one gateway user
-- through user_id.
CREATE TABLE session_thread_sent_muc(
    id INTEGER PRIMARY KEY,
    legacy_id,
    xmpp_id TEXT,
    user_id INTEGER,
    FOREIGN KEY(user_id) REFERENCES user(id),
    -- Legacy IDs only have to be unique per user. A global UNIQUE made a
    -- REPLACE for one user silently delete another user row that happened
    -- to carry the same legacy ID.
    UNIQUE(user_id, legacy_id)
);

CREATE INDEX session_thread_sent_muc_legacy_id
    ON session_thread_sent_muc(legacy_id);
CREATE INDEX session_thread_sent_muc_xmpp_id
    ON session_thread_sent_muc(xmpp_id);
68
+
69
+
70
-- Attachment cache shared by all users: legacy attachment ID -> URL, plus
-- the sims and sfs payloads (text) stored for that URL by AttachmentMixin.
CREATE TABLE attachment(
    id INTEGER PRIMARY KEY,
    legacy_id UNIQUE,
    url TEXT UNIQUE,
    sims TEXT,
    sfs TEXT
);

-- NOTE(review): UNIQUE already gives legacy_id and url an automatic index,
-- so these two explicit indexes look redundant.
CREATE INDEX attachment_legacy_id ON attachment(legacy_id);
CREATE INDEX attachment_url ON attachment(url);
80
+
81
-- One row per legacy message whose attachments were sent as several XMPP
-- messages (see attachment_xmpp_ids).
CREATE TABLE attachment_legacy_msg_id(
    id INTEGER PRIMARY KEY,
    legacy_id UNIQUE
);
85
+
86
-- The XMPP message IDs belonging to one legacy message (many rows per
-- attachment_legacy_msg_id row).
-- NOTE(review): AttachmentMixin filters on both xmpp_id and legacy_msg_id,
-- but neither column is indexed.
CREATE TABLE attachment_xmpp_ids(
    id INTEGER PRIMARY KEY,
    legacy_msg_id INTEGER,
    xmpp_id TEXT,
    FOREIGN KEY(legacy_msg_id) REFERENCES attachment_legacy_msg_id(id)
);
92
+
93
-- Nickname cached for a JID, per gateway user (NickMixin).
CREATE TABLE nick(
    id INTEGER PRIMARY KEY,
    -- No column-level UNIQUE here: it contradicted UNIQUE(jid, user_id) below
    -- and made REPLACE INTO nick for one user delete the nick another user
    -- had stored for the same JID.
    jid,
    nick TEXT,
    user_id INTEGER,
    FOREIGN KEY(user_id) REFERENCES user(id),
    UNIQUE(jid, user_id)
);

CREATE INDEX nick_jid ON nick(jid);
103
+
104
+
105
-- Cached avatar identifier per JID, shared by all users (AvatarMixin).
CREATE TABLE avatar(
    id INTEGER PRIMARY KEY,
    jid TEXT UNIQUE,
    cached_id TEXT
);

-- NOTE(review): jid is UNIQUE and therefore already indexed.
CREATE INDEX avatar_jid ON avatar(jid);
112
+
113
+
114
-- Last presence seen for a JID, per gateway user (PresenceMixin).
-- last_seen is a UNIX timestamp in seconds, or NULL when unknown.
CREATE TABLE presence(
    id INTEGER PRIMARY KEY,
    jid TEXT,
    last_seen INTEGER,
    ptype TEXT,
    pstatus TEXT,
    pshow TEXT,
    user_id INTEGER,
    FOREIGN KEY(user_id) REFERENCES user(id),
    UNIQUE(jid, user_id)
);

CREATE INDEX presence_jid ON presence(jid);
slidge/util/sql.py ADDED
@@ -0,0 +1,508 @@
1
+ import logging
2
+ import os
3
+ import sqlite3
4
+ import tempfile
5
+ from asyncio import AbstractEventLoop, Task, sleep
6
+ from datetime import datetime, timezone
7
+ from functools import lru_cache
8
+ from pathlib import Path
9
+ from time import time
10
+ from typing import (
11
+ TYPE_CHECKING,
12
+ Collection,
13
+ Generic,
14
+ Iterator,
15
+ NamedTuple,
16
+ Optional,
17
+ TypeVar,
18
+ Union,
19
+ )
20
+
21
+ from slixmpp import JID
22
+ from slixmpp.exceptions import XMPPError
23
+ from slixmpp.types import PresenceShows, PresenceTypes
24
+
25
+ from ..core import config
26
+ from .archive_msg import HistoryMessage
27
+
28
+ if TYPE_CHECKING:
29
+ from .db import GatewayUser
30
+
31
# Type parameters of SQLBiDict: the type of the lookup key and of the value
# it maps to (SQLBiDict.inverse swaps the two).
KeyType = TypeVar("KeyType")
ValueType = TypeVar("ValueType")
33
+
34
+
35
class CachedPresence(NamedTuple):
    """Last known presence of a JID, as cached in the ``presence`` table.

    The field order matches the column order PresenceMixin reads and writes:
    ``last_seen, ptype, pstatus, pshow``. Every field defaults to None.
    """

    # When the presence was seen; read back as a timezone-aware UTC datetime
    last_seen: Optional[datetime] = None
    # Presence type
    ptype: Optional[PresenceTypes] = None
    # Status string
    pstatus: Optional[str] = None
    # Presence "show" value
    pshow: Optional[PresenceShows] = None
40
+
41
+
42
class MamMetadata(NamedTuple):
    """ID and timestamp of one archived MUC message."""

    # message_id column of mam_message
    id: str
    # sent_on column, converted to a timezone-aware UTC datetime
    sent_on: datetime
45
+
46
+
47
class Base:
    """Own the SQLite connection shared by every storage mixin of this module.

    The database lives in a fresh temporary file, is initialised from
    ``schema.sql`` (located next to this module), and the file is removed
    when the object is finalised.
    """

    def __init__(self):
        handler, filename = tempfile.mkstemp()
        # sqlite3 opens the file by name, so the raw descriptor is not needed
        os.close(handler)
        # Stored before connecting so that __del__ can still remove the file
        # if connecting or loading the schema fails.
        self.__filename = filename

        self.con = sqlite3.connect(filename)
        self.cur = self.con.cursor()
        self.cur.executescript(
            (Path(__file__).parent / "schema.sql").read_text(encoding="utf-8")
        )

    def __del__(self):
        # __del__ can run on a partially initialised instance (when __init__
        # raised) or more than once, so every step is guarded.
        con = getattr(self, "con", None)
        if con is not None:
            con.close()
        filename = getattr(self, "_Base__filename", None)
        if filename is not None:
            self.__filename = None
            try:
                os.unlink(filename)
            except FileNotFoundError:
                pass
63
+
64
+
65
class MAMMixin(Base):
    """Per-user archive of MUC messages (``muc`` and ``mam_message`` tables).

    Messages are stored as serialized stanzas with their timestamp and read
    back as :class:`HistoryMessage` objects or :class:`MamMetadata` tuples.
    """

    # Scalar sub-queries resolving internal row ids. MUC JIDs are only unique
    # per user (UNIQUE(user_id, jid) on the muc table), so the MUC lookup is
    # scoped to the user as well.
    # Parameters: _USER_ID_SQL -> (user jid); _MUC_ID_SQL -> (muc jid, user jid)
    _USER_ID_SQL = "(SELECT id FROM user WHERE jid = ?)"
    _MUC_ID_SQL = (
        "(SELECT id FROM muc WHERE jid = ? "
        "AND user_id = (SELECT id FROM user WHERE jid = ?))"
    )

    def __init__(self):
        super().__init__()
        self.__mam_cleanup_task: Optional[Task] = None
        # Dedicated cursors, each with its own row factory, so that archive
        # queries directly produce HistoryMessage / MamMetadata objects.
        self.__msg_cur = msg_cur = self.con.cursor()
        msg_cur.row_factory = self.__msg_factory  # type:ignore
        self.__metadata_cur = metadata_cur = self.con.cursor()
        metadata_cur.row_factory = self.__metadata_factory  # type:ignore

    @staticmethod
    def __msg_factory(_cur, row: tuple[str, float]) -> HistoryMessage:
        """Build a HistoryMessage from an ``(xml, sent_on)`` row."""
        return HistoryMessage(
            row[0], when=datetime.fromtimestamp(row[1], tz=timezone.utc)
        )

    @staticmethod
    def __metadata_factory(_cur, row: tuple[str, float]) -> MamMetadata:
        """Build a MamMetadata from a ``(message_id, sent_on)`` row."""
        return MamMetadata(row[0], datetime.fromtimestamp(row[1], tz=timezone.utc))

    def mam_nuke(self):
        """Delete every archived message, for all users and MUCs."""
        self.cur.execute("DELETE FROM mam_message")
        self.con.commit()

    def mam_add_muc(self, jid: str, user: "GatewayUser"):
        """Register MUC *jid* for *user*; does nothing if already registered."""
        try:
            self.cur.execute(
                "INSERT INTO "
                "muc(jid, user_id) "
                "VALUES("
                " ?, "
                " (SELECT id FROM user WHERE jid = ?)"
                ")",
                (jid, user.bare_jid),
            )
        except sqlite3.IntegrityError:
            log.debug("Tried to add a MUC that was already here: (%s, %s)", user, jid)
        else:
            self.con.commit()

    def mam_add_msg(self, muc_jid: str, msg: "HistoryMessage", user: "GatewayUser"):
        """Archive *msg* in *muc_jid* for *user*.

        A message with the same ID in the same MUC of the same user is
        replaced.
        """
        self.cur.execute(
            "REPLACE INTO "
            "mam_message(message_id, sender_jid, sent_on, xml, muc_id, user_id) "
            f"VALUES(?, ?, ?, ?, {self._MUC_ID_SQL}, {self._USER_ID_SQL})",
            (
                msg.id,
                str(msg.stanza.get_from()),
                msg.when.timestamp(),
                str(msg.stanza),
                muc_jid,
                user.bare_jid,
                user.bare_jid,
            ),
        )
        self.con.commit()

    def mam_launch_cleanup_task(self, loop: AbstractEventLoop):
        """Start the background task that periodically purges old messages.

        Launching again cancels the previously launched task, if any.
        """
        if self.__mam_cleanup_task is not None:
            self.__mam_cleanup_task.cancel()
        self.__mam_cleanup_task = loop.create_task(self.__mam_cleanup())

    async def __mam_cleanup(self):
        # Purge every 6 hours for as long as the task lives, not just once:
        # the gateway is a long-running process.
        while True:
            await sleep(6 * 3600)
            try:
                self.mam_cleanup()
            except sqlite3.Error:
                # keep the periodic task alive after a database error
                log.exception("Could not clean up the MAM archive")

    def mam_cleanup(self):
        """Delete messages older than ``config.MAM_MAX_DAYS`` days."""
        self.cur.execute(
            "DELETE FROM mam_message WHERE sent_on < ?",
            (time() - config.MAM_MAX_DAYS * 24 * 3600,),
        )
        self.con.commit()

    def __mam_get_sent_on(self, muc_jid: str, mid: str, user: "GatewayUser"):
        """Return the ``sent_on`` timestamp of message *mid* in *muc_jid*.

        :raises XMPPError: ``item-not-found`` if the message is not archived.
        """
        res = self.cur.execute(
            "SELECT sent_on "
            "FROM mam_message "
            "WHERE message_id = ? "
            f"AND muc_id = {self._MUC_ID_SQL} "
            f"AND user_id = {self._USER_ID_SQL}",
            (mid, muc_jid, user.bare_jid, user.bare_jid),
        )
        row = res.fetchone()
        if row is None:
            raise XMPPError("item-not-found", f"Message {mid} not found")
        return row[0]

    def __mam_bound(
        self,
        muc_jid: str,
        user: "GatewayUser",
        date: Optional[datetime] = None,
        id_: Optional[str] = None,
        comparator=min,
        upper: bool = False,
    ):
        """Build one ``sent_on`` bound of an archive query.

        :param comparator: picks the timestamp to use when both *date* and
            *id_* are given (``max`` for a lower bound, ``min`` for an upper one)
        :param upper: build an upper bound (``<``/``<=``) instead of a lower
            bound (``>``/``>=``)
        :return: ``(sql_fragment, timestamp)``. The bound is exclusive when
            derived from a message ID and inclusive when derived from a date.
        :raises TypeError: if neither *date* nor *id_* is given
        """
        strict, inclusive = ("<", "<=") if upper else (">", ">=")
        if id_ is not None:
            timestamp = self.__mam_get_sent_on(muc_jid, id_, user)
            if date:
                timestamp = comparator(timestamp, date.timestamp())
            return f" AND sent_on {strict} ?", timestamp
        if date is None:
            raise TypeError
        return f" AND sent_on {inclusive} ?", date.timestamp()

    def mam_get_messages(
        self,
        user: "GatewayUser",
        muc_jid: str,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        before_id: Optional[str] = None,
        after_id: Optional[str] = None,
        ids: Collection[str] = (),
        last_page_n: Optional[int] = None,
        sender: Optional[str] = None,
        flip=False,
    ) -> Iterator[HistoryMessage]:
        """Yield archived messages of *muc_jid* for *user*, oldest first.

        :param start_date: only messages sent at or after this date
        :param end_date: only messages sent at or before this date
        :param before_id: only messages sent before this message
        :param after_id: only messages sent after this message
        :param ids: only these message IDs; all of them must exist
        :param last_page_n: only the *last_page_n* most recent matches
        :param sender: only messages whose sender JID equals this
        :param flip: yield newest first instead
        :raises XMPPError: ``item-not-found`` when a referenced message ID
            is unknown
        """
        query = (
            "SELECT xml, sent_on FROM mam_message "
            f"WHERE muc_id = {self._MUC_ID_SQL} "
            f"AND user_id = {self._USER_ID_SQL} "
        )
        params: list[Union[str, float, int]] = [muc_jid, user.bare_jid, user.bare_jid]

        if start_date or after_id:
            # lower bound: keep the most restrictive (latest) candidate
            subquery, timestamp = self.__mam_bound(
                muc_jid, user, start_date, after_id, max
            )
            query += subquery
            params.append(timestamp)
        if end_date or before_id:
            # upper bound: keep the most restrictive (earliest) candidate
            subquery, timestamp = self.__mam_bound(
                muc_jid, user, end_date, before_id, min, upper=True
            )
            query += subquery
            params.append(timestamp)
        if sender:
            query += " AND sender_jid = ?"
            params.append(sender)
        if ids:
            query += f" AND message_id IN ({','.join('?' * len(ids))})"
            params.extend(ids)
        if last_page_n:
            # TODO: optimize query further when <flip> and last page are
            # combined.
            query = f"SELECT * FROM ({query} ORDER BY sent_on DESC LIMIT ?)"
            params.append(last_page_n)
        query += " ORDER BY sent_on"
        if flip:
            query += " DESC"

        res = self.__msg_cur.execute(query, params)

        if ids:
            rows = res.fetchall()
            # compare against distinct IDs: duplicates in `ids` match one row
            if len(rows) != len(set(ids)):
                raise XMPPError(
                    "item-not-found",
                    "One of the requested messages IDs could not be found "
                    "with the given constraints.",
                )
            yield from rows
            return

        while row := res.fetchone():
            yield row

    def mam_get_first_and_last(
        self, muc_jid: str, user: Optional["GatewayUser"] = None
    ) -> list[MamMetadata]:
        """Return metadata of the oldest and newest archived messages of a MUC.

        Rows are sorted by ``sent_on``. Without *user*, the copies of this MUC
        registered by every user are considered.
        """
        # Each message is joined to its *own* MUC row. Joining on the JID alone
        # also matched messages of other MUCs that shared the min/max timestamp.
        query = (
            "SELECT message_id, sent_on "
            "FROM mam_message "
            "JOIN muc ON muc.id = mam_message.muc_id "
            "WHERE muc.jid = ? "
            "AND (sent_on = (SELECT MIN(sent_on) FROM mam_message WHERE muc_id = muc.id) "
            " OR sent_on = (SELECT MAX(sent_on) FROM mam_message WHERE muc_id = muc.id))"
        )
        params: list[str] = [muc_jid]
        if user is not None:
            query += f" AND muc.user_id = {self._USER_ID_SQL}"
            params.append(user.bare_jid)
        query += " ORDER BY sent_on"
        res = self.__metadata_cur.execute(query, params)
        return res.fetchall()
245
+
246
+
247
class AttachmentMixin(Base):
    """Attachment cache and legacy-message <-> XMPP-messages ID mapping.

    Uses the ``attachment``, ``attachment_legacy_msg_id`` and
    ``attachment_xmpp_ids`` tables.
    """

    def attachment_remove(self, legacy_id):
        """Forget the cached attachment with this legacy ID."""
        self.cur.execute("DELETE FROM attachment WHERE legacy_id = ?", (legacy_id,))
        self.con.commit()

    def attachment_store_url(self, legacy_id, url: str):
        """Associate *url* with the legacy attachment *legacy_id*.

        REPLACE drops any existing row with the same legacy ID or URL, so the
        sims/sfs data previously stored on that row is discarded.
        """
        self.cur.execute(
            "REPLACE INTO attachment(legacy_id, url) VALUES (?,?)", (legacy_id, url)
        )
        self.con.commit()

    def attachment_store_sims(self, url: str, sims: str):
        """Store the sims payload of the attachment at *url* (if known)."""
        self.cur.execute("UPDATE attachment SET sims = ? WHERE url = ?", (sims, url))
        self.con.commit()

    def attachment_store_sfs(self, url: str, sfs: str):
        """Store the sfs payload of the attachment at *url* (if known)."""
        self.cur.execute("UPDATE attachment SET sfs = ? WHERE url = ?", (sfs, url))
        self.con.commit()

    def attachment_get_url(self, legacy_id):
        """Return the URL stored for *legacy_id*, or None."""
        res = self.cur.execute(
            "SELECT url FROM attachment WHERE legacy_id = ?", (legacy_id,)
        )
        return first_of_tuple_or_none(res.fetchone())

    def attachment_get_sims(self, url: str):
        """Return the sims payload stored for *url*, or None."""
        res = self.cur.execute("SELECT sims FROM attachment WHERE url = ?", (url,))
        return first_of_tuple_or_none(res.fetchone())

    def attachment_get_sfs(self, url: str):
        """Return the sfs payload stored for *url*, or None."""
        res = self.cur.execute("SELECT sfs FROM attachment WHERE url = ?", (url,))
        return first_of_tuple_or_none(res.fetchone())

    def attachment_store_legacy_to_multi_xmpp_msg_ids(
        self, legacy_id, xmpp_ids: list[str]
    ):
        """Record that legacy message *legacy_id* maps to all of *xmpp_ids*.

        Runs in a single transaction (committed on success, rolled back on
        error).
        """
        with self.con:
            self.cur.execute(
                "INSERT OR IGNORE INTO attachment_legacy_msg_id(legacy_id) VALUES (?)",
                (legacy_id,),
            )
            # Cursor.lastrowid cannot be used here: when the INSERT is ignored
            # because legacy_id is already known, it still holds the rowid of
            # an earlier, unrelated insert. Look the row id up explicitly.
            row_id = self.cur.execute(
                "SELECT id FROM attachment_legacy_msg_id WHERE legacy_id = ?",
                (legacy_id,),
            ).fetchone()[0]
            self.cur.executemany(
                "INSERT INTO attachment_xmpp_ids(legacy_msg_id, xmpp_id) VALUES (?, ?)",
                ((row_id, i) for i in xmpp_ids),
            )

    def attachment_get_xmpp_ids_for_legacy_msg_id(self, legacy_id) -> list:
        """Return every XMPP message ID recorded for legacy message *legacy_id*."""
        res = self.cur.execute(
            "SELECT xmpp_id FROM attachment_xmpp_ids "
            "WHERE legacy_msg_id = (SELECT id FROM attachment_legacy_msg_id WHERE legacy_id = ?)",
            (legacy_id,),
        )
        return [r[0] for r in res.fetchall()]

    def attachment_get_associated_xmpp_ids(self, xmpp_id: str):
        """Return the other XMPP message IDs sharing a legacy message with *xmpp_id*."""
        res = self.cur.execute(
            "SELECT xmpp_id FROM attachment_xmpp_ids "
            "WHERE legacy_msg_id = "
            "(SELECT legacy_msg_id FROM attachment_xmpp_ids WHERE xmpp_id = ?)",
            (xmpp_id,),
        )
        return [r[0] for r in res.fetchall() if r[0] != xmpp_id]

    def attachment_get_legacy_id_for_xmpp_id(self, xmpp_id: str):
        """Return the legacy message ID that *xmpp_id* belongs to, or None."""
        res = self.cur.execute(
            "SELECT legacy_id FROM attachment_legacy_msg_id "
            "WHERE id = (SELECT legacy_msg_id FROM attachment_xmpp_ids WHERE xmpp_id = ?)",
            (xmpp_id,),
        )
        return first_of_tuple_or_none(res.fetchone())
319
+
320
+
321
class NickMixin(Base):
    """Per-user nickname cache backed by the ``nick`` table."""

    def nick_get(self, jid: JID, user: "GatewayUser"):
        """Return the nickname *user* has stored for *jid*, or None."""
        lookup = (
            "SELECT nick FROM nick "
            "WHERE jid = ? "
            "AND user_id = (SELECT id FROM user WHERE jid = ?)"
        )
        found = self.cur.execute(lookup, (str(jid), user.bare_jid)).fetchone()
        return None if found is None else found[0]

    def nick_store(self, jid: JID, nick: str, user: "GatewayUser"):
        """Insert or overwrite the nickname of *jid* for *user*, then commit."""
        row = (str(jid), nick, user.bare_jid)
        statement = (
            "REPLACE INTO nick(jid, nick, user_id) "
            "VALUES (?,?,(SELECT id FROM user WHERE jid = ?))"
        )
        self.cur.execute(statement, row)
        self.con.commit()
338
+
339
+
340
class AvatarMixin(Base):
    """JID -> cached avatar identifier store backed by the ``avatar`` table."""

    def avatar_get(self, jid: JID):
        """Return the cached avatar ID stored for *jid*, or None."""
        found = self.cur.execute(
            "SELECT cached_id FROM avatar WHERE jid = ?", (str(jid),)
        ).fetchone()
        return found[0] if found is not None else None

    def avatar_store(self, jid: JID, cached_id: Union[int, str]):
        """Insert or overwrite the cached avatar ID of *jid*, then commit."""
        values = (str(jid), cached_id)
        self.cur.execute("REPLACE INTO avatar(jid, cached_id) VALUES (?,?)", values)
        self.con.commit()

    def avatar_delete(self, jid: JID):
        """Remove the cached avatar ID of *jid* (if any), then commit."""
        key = (str(jid),)
        self.cur.execute("DELETE FROM avatar WHERE jid = ?", key)
        self.con.commit()
356
+
357
+
358
class PresenceMixin(Base):
    """Per-user cache of the last presence seen for each JID (``presence`` table)."""

    def __init__(self):
        super().__init__()
        presence_cursor = self.con.cursor()
        # rows fetched through this cursor are turned into CachedPresence tuples
        presence_cursor.row_factory = self.__row_factory  # type:ignore
        self.__cur = presence_cursor

    @staticmethod
    def __row_factory(
        _cur: sqlite3.Cursor,
        row: tuple[
            Optional[int],
            Optional[PresenceTypes],
            Optional[str],
            Optional[PresenceShows],
        ],
    ):
        """Convert a ``(last_seen, ptype, pstatus, pshow)`` row to CachedPresence."""
        stamp, *remaining = row
        seen = (
            None if stamp is None else datetime.fromtimestamp(stamp, tz=timezone.utc)
        )
        return CachedPresence(seen, *remaining)

    def presence_nuke(self):
        """Delete every cached presence (useful for tests)."""
        self.cur.execute("DELETE FROM presence")
        self.con.commit()

    def presence_store(self, jid: JID, presence: CachedPresence, user: "GatewayUser"):
        """Insert or overwrite the cached presence of *jid* for *user*."""
        seen, *remaining = presence
        stamp = None if not seen else seen.timestamp()
        self.cur.execute(
            "REPLACE INTO presence(jid, last_seen, ptype, pstatus, pshow, user_id) "
            "VALUES (?,?,?,?,?,(SELECT id FROM user WHERE jid = ?))",
            (str(jid), stamp, *remaining, user.bare_jid),
        )
        self.con.commit()

    def presence_delete(self, jid: JID, user: "GatewayUser"):
        """Forget the cached presence of *jid* for *user*."""
        params = (str(jid), user.bare_jid)
        self.cur.execute(
            "DELETE FROM presence WHERE (jid = ? and user_id = (SELECT id FROM user WHERE jid = ?))",
            params,
        )
        self.con.commit()

    def presence_get(self, jid: JID, user: "GatewayUser") -> Optional[CachedPresence]:
        """Return the cached presence of *jid* for *user*, or None."""
        params = (str(jid), user.bare_jid)
        cursor = self.__cur.execute(
            "SELECT last_seen, ptype, pstatus, pshow FROM presence "
            "WHERE jid = ? AND user_id = (SELECT id FROM user WHERE jid = ?)",
            params,
        )
        return cursor.fetchone()
411
+
412
+
413
class UserMixin(Base):
    """Registration of gateway users in the ``user`` table."""

    def user_store(self, user: "GatewayUser"):
        """Add *user* by bare JID; an already stored user is only logged."""
        try:
            self.cur.execute("INSERT INTO user(jid) VALUES (?)", (user.bare_jid,))
        except sqlite3.IntegrityError:
            log.debug("User has already been added.")
            return
        self.con.commit()

    def user_del(self, user: "GatewayUser"):
        """Remove the row of *user* from the ``user`` table, then commit."""
        bare = (user.bare_jid,)
        self.cur.execute("DELETE FROM user WHERE jid = ?", bare)
        self.con.commit()
425
+
426
+
427
def first_of_tuple_or_none(x: Optional[tuple]):
    """Return ``x[0]``, or None when *x* itself is None.

    Suited to ``cursor.fetchone()`` on single-column queries, which yields
    None when no row matched.
    """
    return None if x is None else x[0]
431
+
432
+
433
class SQLBiDict(Generic[KeyType, ValueType]):
    """Dict-like two-way mapping stored in an SQL table, scoped to one user.

    ``self[k]`` looks up the ``key2`` value of the row whose ``key1`` is *k*;
    ``self.inverse`` is the same table viewed the other way round.

    :param table: table name (trusted, interpolated into SQL)
    :param key1: column used as the lookup key
    :param key2: column holding the mapped value
    :param user: owner of the rows; every query is filtered by it
    :param sql: database to use, defaults to the module-level ``db``
    :param create_table: create *table* first
    :param is_inverse: internal flag, set for the ``inverse`` view
    """

    def __init__(
        self,
        table: str,
        key1: str,
        key2: str,
        user: "GatewayUser",
        sql: Optional[Base] = None,
        create_table=False,
        is_inverse=False,
    ):
        if sql is None:
            sql = db
        self.db = sql
        self.table = table
        self.key1 = key1
        self.key2 = key2
        self.user = user
        if create_table:
            # Keys are unique per user only, so a REPLACE for one user cannot
            # delete another user's row that happens to share a key.
            sql.cur.execute(
                f"CREATE TABLE {table} (id "
                "INTEGER PRIMARY KEY,"
                "user_id INTEGER,"
                f"{key1},"
                f"{key2},"
                f"UNIQUE(user_id, {key1}),"
                f"UNIQUE(user_id, {key2}),"
                f"FOREIGN KEY(user_id) REFERENCES user(id))",
            )
        if is_inverse:
            return
        self.inverse = SQLBiDict[ValueType, KeyType](
            table, key2, key1, user, sql=sql, is_inverse=True
        )

    def __setitem__(self, key: KeyType, value: ValueType):
        self.db.cur.execute(
            f"REPLACE INTO {self.table}"
            f"(user_id, {self.key1}, {self.key2}) "
            "VALUES ((SELECT id FROM user WHERE jid = ?), ?, ?)",
            (self.user.bare_jid, key, value),
        )
        self.db.con.commit()
        # `get` results are cached for all instances (this mapping, its
        # inverse, other users). Any write may invalidate any of them,
        # including cached None results for keys that now exist.
        SQLBiDict.get.cache_clear()

    def __getitem__(self, item: KeyType) -> ValueType:
        v = self.get(item)
        if v is None:
            raise KeyError(item)
        return v

    def __contains__(self, item: KeyType) -> bool:
        res = self.db.cur.execute(
            f"SELECT {self.key1} FROM {self.table} "
            f"WHERE {self.key1} = ? AND user_id = (SELECT id FROM user WHERE jid = ?)",
            (item, self.user.bare_jid),
        ).fetchone()
        return res is not None

    @lru_cache(100)
    def get(self, item: KeyType) -> Optional[ValueType]:
        """Return the value mapped to *item*, or None if there is none.

        Results are cached; the cache is cleared by ``__setitem__``.
        """
        res = self.db.cur.execute(
            f"SELECT {self.key2} FROM {self.table} "
            f"WHERE {self.key1} = ? AND user_id = (SELECT id FROM user WHERE jid = ?)",
            (item, self.user.bare_jid),
        ).fetchone()
        if res is None:
            return None
        return res[0]
499
+
500
+
501
class TemporaryDB(
    AvatarMixin, AttachmentMixin, NickMixin, MAMMixin, UserMixin, PresenceMixin
):
    """Single database object combining every storage mixin of this module.

    All mixins share the connection created by Base in a temporary file.
    The mixins whose __init__ call super().__init__() (MAMMixin,
    PresenceMixin) rely on this base order to reach Base exactly once.
    """

    pass
505
+
506
+
507
# Process-wide database instance, created at import time. SQLBiDict uses it
# when no explicit `sql` argument is given.
db = TemporaryDB()
# Module logger, used by the mixins above.
log = logging.getLogger(__name__)