slidge 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (102):
  1. slidge/__init__.py +3 -5
  2. slidge/__main__.py +2 -197
  3. slidge/__version__.py +5 -0
  4. slidge/command/adhoc.py +40 -17
  5. slidge/command/admin.py +24 -12
  6. slidge/command/base.py +10 -8
  7. slidge/command/categories.py +13 -3
  8. slidge/command/chat_command.py +29 -2
  9. slidge/command/register.py +32 -16
  10. slidge/command/user.py +106 -13
  11. slidge/contact/contact.py +254 -50
  12. slidge/contact/roster.py +124 -53
  13. slidge/core/config.py +19 -13
  14. slidge/core/dispatcher/__init__.py +3 -0
  15. slidge/core/{gateway → dispatcher}/caps.py +12 -8
  16. slidge/core/{gateway → dispatcher}/disco.py +10 -18
  17. slidge/core/dispatcher/message/__init__.py +10 -0
  18. slidge/core/dispatcher/message/chat_state.py +40 -0
  19. slidge/core/dispatcher/message/marker.py +62 -0
  20. slidge/core/dispatcher/message/message.py +397 -0
  21. slidge/core/dispatcher/muc/__init__.py +12 -0
  22. slidge/core/dispatcher/muc/admin.py +98 -0
  23. slidge/core/{gateway → dispatcher/muc}/mam.py +25 -17
  24. slidge/core/dispatcher/muc/misc.py +121 -0
  25. slidge/core/dispatcher/muc/owner.py +96 -0
  26. slidge/core/{gateway → dispatcher/muc}/ping.py +11 -17
  27. slidge/core/dispatcher/presence.py +176 -0
  28. slidge/core/dispatcher/registration.py +85 -0
  29. slidge/core/{gateway → dispatcher}/search.py +9 -16
  30. slidge/core/dispatcher/session_dispatcher.py +84 -0
  31. slidge/core/dispatcher/util.py +174 -0
  32. slidge/core/{gateway/vcard_temp.py → dispatcher/vcard.py} +35 -19
  33. slidge/core/{gateway/base.py → gateway.py} +176 -153
  34. slidge/core/mixins/__init__.py +11 -1
  35. slidge/core/mixins/attachment.py +106 -67
  36. slidge/core/mixins/avatar.py +94 -25
  37. slidge/core/mixins/base.py +10 -4
  38. slidge/core/mixins/db.py +18 -0
  39. slidge/core/mixins/disco.py +0 -10
  40. slidge/core/mixins/lock.py +10 -8
  41. slidge/core/mixins/message.py +11 -195
  42. slidge/core/mixins/message_maker.py +17 -9
  43. slidge/core/mixins/message_text.py +211 -0
  44. slidge/core/mixins/presence.py +17 -4
  45. slidge/core/pubsub.py +114 -288
  46. slidge/core/session.py +101 -40
  47. slidge/db/__init__.py +4 -0
  48. slidge/db/alembic/__init__.py +0 -0
  49. slidge/db/alembic/env.py +64 -0
  50. slidge/db/alembic/old_user_store.py +183 -0
  51. slidge/db/alembic/script.py.mako +26 -0
  52. slidge/db/alembic/versions/09f27f098baa_add_missing_attributes_in_room.py +36 -0
  53. slidge/db/alembic/versions/15b0bd83407a_remove_bogus_unique_constraints_on_room_.py +85 -0
  54. slidge/db/alembic/versions/2461390c0af2_store_contacts_caps_verstring_in_db.py +36 -0
  55. slidge/db/alembic/versions/29f5280c61aa_store_subject_setter_in_room.py +37 -0
  56. slidge/db/alembic/versions/2b1f45ab7379_store_room_subject_setter_by_nickname.py +41 -0
  57. slidge/db/alembic/versions/3071e0fa69d4_add_contact_client_type.py +52 -0
  58. slidge/db/alembic/versions/45c24cc73c91_add_bob.py +42 -0
  59. slidge/db/alembic/versions/5bd48bfdffa2_lift_room_legacy_id_constraint.py +61 -0
  60. slidge/db/alembic/versions/82a4af84b679_add_muc_history_filled.py +48 -0
  61. slidge/db/alembic/versions/8b993243a536_add_vcard_content_to_contact_table.py +43 -0
  62. slidge/db/alembic/versions/8d2ced764698_rely_on_db_to_store_contacts_rooms_and_.py +139 -0
  63. slidge/db/alembic/versions/aa9d82a7f6ef_db_creation.py +101 -0
  64. slidge/db/alembic/versions/abba1ae0edb3_store_avatar_legacy_id_in_the_contact_.py +79 -0
  65. slidge/db/alembic/versions/b33993e87db3_move_everything_to_persistent_db.py +214 -0
  66. slidge/db/alembic/versions/b64b1a793483_add_source_and_legacy_id_for_archived_.py +52 -0
  67. slidge/db/alembic/versions/c4a8ec35a0e8_per_room_user_nick.py +34 -0
  68. slidge/db/alembic/versions/e91195719c2c_store_users_avatars_persistently.py +26 -0
  69. slidge/db/avatar.py +205 -0
  70. slidge/db/meta.py +72 -0
  71. slidge/db/models.py +405 -0
  72. slidge/db/store.py +1257 -0
  73. slidge/group/archive.py +58 -14
  74. slidge/group/bookmarks.py +89 -65
  75. slidge/group/participant.py +107 -40
  76. slidge/group/room.py +402 -213
  77. slidge/main.py +202 -0
  78. slidge/migration.py +45 -1
  79. slidge/slixfix/__init__.py +31 -1
  80. slidge/{core/gateway → slixfix}/delivery_receipt.py +1 -1
  81. slidge/slixfix/roster.py +13 -4
  82. slidge/slixfix/xep_0292/vcard4.py +1 -87
  83. slidge/util/archive_msg.py +2 -1
  84. slidge/util/db.py +4 -228
  85. slidge/util/test.py +91 -4
  86. slidge/util/types.py +39 -4
  87. slidge/util/util.py +45 -2
  88. {slidge-0.1.3.dist-info → slidge-0.2.0.dist-info}/METADATA +10 -5
  89. slidge-0.2.0.dist-info/RECORD +131 -0
  90. slidge-0.2.0.dist-info/entry_points.txt +3 -0
  91. slidge/core/cache.py +0 -183
  92. slidge/core/gateway/__init__.py +0 -3
  93. slidge/core/gateway/muc_admin.py +0 -35
  94. slidge/core/gateway/presence.py +0 -95
  95. slidge/core/gateway/registration.py +0 -53
  96. slidge/core/gateway/session_dispatcher.py +0 -804
  97. slidge/util/schema.sql +0 -126
  98. slidge/util/sql.py +0 -508
  99. slidge-0.1.3.dist-info/RECORD +0 -96
  100. slidge-0.1.3.dist-info/entry_points.txt +0 -3
  101. {slidge-0.1.3.dist-info → slidge-0.2.0.dist-info}/LICENSE +0 -0
  102. {slidge-0.1.3.dist-info → slidge-0.2.0.dist-info}/WHEEL +0 -0
slidge/db/store.py ADDED
@@ -0,0 +1,1257 @@
1
+ from __future__ import annotations
2
+
3
+ import hashlib
4
+ import json
5
+ import logging
6
+ import uuid
7
+ from contextlib import contextmanager
8
+ from datetime import datetime, timedelta, timezone
9
+ from mimetypes import guess_extension
10
+ from typing import TYPE_CHECKING, Collection, Iterator, Optional, Type
11
+
12
+ from slixmpp import JID, Iq, Message, Presence
13
+ from slixmpp.exceptions import XMPPError
14
+ from slixmpp.plugins.xep_0231.stanza import BitsOfBinary
15
+ from sqlalchemy import Engine, delete, select, update
16
+ from sqlalchemy.orm import Session, attributes, load_only
17
+ from sqlalchemy.sql.functions import count
18
+
19
+ from ..core import config
20
+ from ..util.archive_msg import HistoryMessage
21
+ from ..util.types import URL, CachedPresence, ClientType
22
+ from ..util.types import Hat as HatTuple
23
+ from ..util.types import MamMetadata, MucAffiliation, MucRole, Sticker
24
+ from .meta import Base
25
+ from .models import (
26
+ ArchivedMessage,
27
+ ArchivedMessageSource,
28
+ Attachment,
29
+ Avatar,
30
+ Bob,
31
+ Contact,
32
+ ContactSent,
33
+ GatewayUser,
34
+ Hat,
35
+ LegacyIdsMulti,
36
+ Participant,
37
+ Room,
38
+ XmppIdsMulti,
39
+ XmppToLegacyEnum,
40
+ XmppToLegacyIds,
41
+ participant_hats,
42
+ )
43
+
44
+ if TYPE_CHECKING:
45
+ from ..contact.contact import LegacyContact
46
+ from ..group.participant import LegacyParticipant
47
+ from ..group.room import LegacyMUC
48
+
49
+
50
class EngineMixin:
    """Common base of every store: keeps the engine and hands out ORM sessions."""

    def __init__(self, engine: Engine):
        self._engine = engine

    @contextmanager
    def session(self, **session_kwargs) -> Iterator[Session]:
        """
        Yield an ORM session, reusing the currently open one if there is any.

        The open session is tracked in the module-level ``_session`` global, so
        nested ``with ....session()`` blocks share a single session, even when
        they belong to different store instances. ``session_kwargs`` are only
        used when a new session has to be created; they are ignored when an
        outer session is reused.
        """
        global _session
        if _session is None:
            with Session(self._engine, **session_kwargs) as fresh:
                _session = fresh
                try:
                    yield fresh
                finally:
                    # Stop advertising this session once the outermost block
                    # exits, whether normally or through an exception.
                    _session = None
        else:
            yield _session
66
+
67
+
68
class UpdatedMixin(EngineMixin):
    """
    Store for a table that has an ``updated`` boolean column.

    Subclasses set ``model`` to the mapped class they manage. On instantiation
    the ``updated`` flag of every row of that table is reset to False.
    """

    model: Type[Base] = NotImplemented

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)
        clear_flags = update(self.model).values(updated=False)  # type:ignore
        with self.session() as session:
            session.execute(clear_flags)
            session.commit()

    def get_by_pk(self, pk: int) -> Optional[Base]:
        """Return the ``model`` row whose ``id`` is ``pk``, or None."""
        by_pk = select(self.model).where(self.model.id == pk)  # type:ignore
        with self.session() as session:
            return session.execute(by_pk).scalar()
82
+
83
+
84
class SlidgeStore(EngineMixin):
    """Facade that groups every per-table store, all bound to the same engine."""

    def __init__(self, engine: Engine):
        super().__init__(engine)
        # Several of these constructors run UPDATE statements to reset
        # transient columns, so they are instantiated here, in this order.
        self.users = UserStore(engine)
        self.avatars = AvatarStore(engine)
        self.contacts = ContactStore(engine)
        self.mam = MAMStore(engine)
        self.multi = MultiStore(engine)
        self.attachments = AttachmentStore(engine)
        self.rooms = RoomStore(engine)
        self.sent = SentStore(engine)
        self.participants = ParticipantStore(engine)
        self.bob = BobStore(engine)
97
+
98
+
99
class UserStore(EngineMixin):
    """Persistence for gateway users, one row per bare JID."""

    def new(self, jid: JID, legacy_module_data: dict) -> GatewayUser:
        """
        Return the user row for ``jid``, creating it first if needed.

        Any resource part of ``jid`` is dropped. ``legacy_module_data`` is only
        used when a new row is created.
        """
        if jid.resource:
            jid = JID(jid.bare)
        with self.session(expire_on_commit=False) as session:
            user = session.execute(
                select(GatewayUser).where(GatewayUser.jid == jid)
            ).scalar()
            if user is not None:
                return user
            user = GatewayUser(jid=jid, legacy_module_data=legacy_module_data)
            session.add(user)
            session.commit()
            return user

    def update(self, user: GatewayUser):
        """Persist ``user``, including in-place changes to its data columns."""
        # flag_modified forces these columns to be written even if only their
        # contents were mutated in place, see:
        # https://github.com/sqlalchemy/sqlalchemy/discussions/6473
        attributes.flag_modified(user, "legacy_module_data")
        attributes.flag_modified(user, "preferences")
        with self.session() as session:
            session.add(user)
            session.commit()

    def get_all(self) -> Iterator[GatewayUser]:
        """Yield every gateway user row."""
        with self.session() as session:
            yield from session.execute(select(GatewayUser)).scalars()

    def get(self, jid: JID) -> Optional[GatewayUser]:
        """Return the user whose JID equals the bare part of ``jid``, or None."""
        with self.session() as session:
            return session.execute(
                select(GatewayUser).where(GatewayUser.jid == jid.bare)
            ).scalar()

    def get_by_stanza(self, stanza: Iq | Message | Presence) -> Optional[GatewayUser]:
        """Return the user who sent ``stanza``, or None."""
        return self.get(stanza.get_from())

    def delete(self, jid: JID) -> None:
        """
        Delete the user matching the bare part of ``jid``.

        Does nothing if there is no such user. Without this guard, ``None``
        would be passed to ``Session.delete``, which raises an unrelated ORM
        error.
        """
        with self.session() as session:
            user = self.get(jid)
            if user is None:
                return
            session.delete(user)
            session.commit()

    def set_avatar_hash(self, pk: int, h: str | None = None) -> None:
        """Store (or clear, with ``h=None``) the avatar hash of user ``pk``."""
        with self.session() as session:
            session.execute(
                update(GatewayUser).where(GatewayUser.id == pk).values(avatar_hash=h)
            )
            session.commit()
146
+
147
+
148
class AvatarStore(EngineMixin):
    """Lookup, listing and deletion of stored avatar rows."""

    def get_by_url(self, url: URL | str) -> Optional[Avatar]:
        """Return the avatar stored for ``url``, or None."""
        by_url = select(Avatar).where(Avatar.url == url)
        with self.session() as session:
            return session.execute(by_url).scalar()

    def get_by_pk(self, pk: int) -> Optional[Avatar]:
        """Return the avatar with primary key ``pk``, or None."""
        by_pk = select(Avatar).where(Avatar.id == pk)
        with self.session() as session:
            return session.execute(by_pk).scalar()

    def delete_by_pk(self, pk: int):
        """Delete the avatar with primary key ``pk`` (no-op if absent)."""
        removal = delete(Avatar).where(Avatar.id == pk)
        with self.session() as session:
            session.execute(removal)
            session.commit()

    def get_all(self) -> Iterator[Avatar]:
        """Yield every avatar row."""
        everything = select(Avatar)
        with self.session() as session:
            yield from session.execute(everything).scalars()
165
+
166
+
167
class SentStore(EngineMixin):
    """
    Correspondence between XMPP message IDs and legacy message IDs.

    Each row belongs to one gateway user and is tagged with an
    ``XmppToLegacyEnum`` kind: DM, GROUP_CHAT or THREAD.
    """

    def _first(self, stmt):
        # Execute a select and return the first column of the first row, or None.
        with self.session() as session:
            return session.execute(stmt).scalar()

    def _insert(
        self, user_pk: int, legacy_id: str, xmpp_id: str, kind: XmppToLegacyEnum
    ) -> None:
        # Unconditionally add a new mapping row of the given kind.
        row = XmppToLegacyIds(
            user_account_id=user_pk,
            legacy_id=legacy_id,
            xmpp_id=xmpp_id,
            type=kind,
        )
        with self.session() as session:
            session.add(row)
            session.commit()

    def _xmpp_id_of(
        self, user_pk: int, legacy_id: str, kind: XmppToLegacyEnum
    ) -> Optional[str]:
        # legacy ID -> XMPP ID, for one user and one kind.
        return self._first(
            select(XmppToLegacyIds.xmpp_id)
            .where(XmppToLegacyIds.user_account_id == user_pk)
            .where(XmppToLegacyIds.legacy_id == legacy_id)
            .where(XmppToLegacyIds.type == kind)
        )

    def _legacy_id_of(
        self, user_pk: int, xmpp_id: str, kind: XmppToLegacyEnum
    ) -> Optional[str]:
        # XMPP ID -> legacy ID, for one user and one kind.
        return self._first(
            select(XmppToLegacyIds.legacy_id)
            .where(XmppToLegacyIds.user_account_id == user_pk)
            .where(XmppToLegacyIds.xmpp_id == xmpp_id)
            .where(XmppToLegacyIds.type == kind)
        )

    def set_message(self, user_pk: int, legacy_id: str, xmpp_id: str) -> None:
        """Record a DM mapping, reusing an identical existing row if present."""
        with self.session() as session:
            msg = (
                session.query(XmppToLegacyIds)
                .filter(XmppToLegacyIds.user_account_id == user_pk)
                .filter(XmppToLegacyIds.legacy_id == legacy_id)
                .filter(XmppToLegacyIds.xmpp_id == xmpp_id)
                .scalar()
            )
            if msg is not None:
                log.debug("Resetting a DM from sent store")
            else:
                msg = XmppToLegacyIds(user_account_id=user_pk)
            msg.legacy_id = legacy_id
            msg.xmpp_id = xmpp_id
            msg.type = XmppToLegacyEnum.DM
            session.add(msg)
            session.commit()

    def get_xmpp_id(self, user_pk: int, legacy_id: str) -> Optional[str]:
        """XMPP ID of the DM with legacy ID ``legacy_id``, or None."""
        return self._xmpp_id_of(user_pk, legacy_id, XmppToLegacyEnum.DM)

    def get_legacy_id(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        """Legacy ID of the DM with XMPP ID ``xmpp_id``, or None."""
        return self._legacy_id_of(user_pk, xmpp_id, XmppToLegacyEnum.DM)

    def set_group_message(self, user_pk: int, legacy_id: str, xmpp_id: str) -> None:
        """Record a group chat message mapping."""
        self._insert(user_pk, legacy_id, xmpp_id, XmppToLegacyEnum.GROUP_CHAT)

    def get_group_xmpp_id(self, user_pk: int, legacy_id: str) -> Optional[str]:
        """XMPP ID of the group message with legacy ID ``legacy_id``, or None."""
        return self._xmpp_id_of(user_pk, legacy_id, XmppToLegacyEnum.GROUP_CHAT)

    def get_group_legacy_id(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        """Legacy ID of the group message with XMPP ID ``xmpp_id``, or None."""
        return self._legacy_id_of(user_pk, xmpp_id, XmppToLegacyEnum.GROUP_CHAT)

    def set_thread(self, user_pk: int, legacy_id: str, xmpp_id: str) -> None:
        """Record a thread ID mapping."""
        self._insert(user_pk, legacy_id, xmpp_id, XmppToLegacyEnum.THREAD)

    def get_legacy_thread(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        """Legacy thread ID for the XMPP thread ``xmpp_id``, or None."""
        return self._legacy_id_of(user_pk, xmpp_id, XmppToLegacyEnum.THREAD)

    def was_sent_by_user(self, user_pk: int, legacy_id: str) -> bool:
        """True if any mapping of any kind exists for ``legacy_id`` and this user."""
        found = self._first(
            select(XmppToLegacyIds.legacy_id)
            .where(XmppToLegacyIds.user_account_id == user_pk)
            .where(XmppToLegacyIds.legacy_id == legacy_id)
        )
        return found is not None
264
+
265
+
266
class ContactStore(UpdatedMixin):
    """Persistence for legacy contacts, one row per (gateway user, contact)."""

    model = Contact

    def __init__(self, *a, **k):
        super().__init__(*a, **k)
        # Flag every cached presence as invalid, so that get_presence()
        # returns None until set_presence() is called again for a contact.
        with self.session() as session:
            session.execute(update(Contact).values(cached_presence=False))
            session.commit()

    def _update_columns(self, contact_pk: int, **values) -> None:
        # Write the given column values on a single contact row and commit.
        with self.session() as session:
            session.execute(
                update(Contact).where(Contact.id == contact_pk).values(**values)
            )
            session.commit()

    def get_all(self, user_pk: int) -> Iterator[Contact]:
        """Yield every contact row of gateway user ``user_pk``."""
        with self.session() as session:
            yield from session.execute(
                select(Contact).where(Contact.user_account_id == user_pk)
            ).scalars()

    def get_by_jid(self, user_pk: int, jid: JID) -> Optional[Contact]:
        """Return the contact whose JID is the bare part of ``jid``, or None."""
        with self.session() as session:
            return session.execute(
                select(Contact)
                .where(Contact.jid == jid.bare)
                .where(Contact.user_account_id == user_pk)
            ).scalar()

    def get_by_legacy_id(self, user_pk: int, legacy_id: str) -> Optional[Contact]:
        """Return the contact with this legacy ID, or None."""
        with self.session() as session:
            return session.execute(
                select(Contact)
                .where(Contact.legacy_id == legacy_id)
                .where(Contact.user_account_id == user_pk)
            ).scalar()

    def update_nick(self, contact_pk: int, nick: Optional[str]) -> None:
        """Set (or clear) the nickname of a contact."""
        self._update_columns(contact_pk, nick=nick)

    def get_presence(self, contact_pk: int) -> Optional[CachedPresence]:
        """
        Return the cached presence of a contact.

        Returns None if the contact does not exist or if its presence is not
        flagged as cached.
        """
        with self.session() as session:
            presence = session.execute(
                select(
                    Contact.last_seen,
                    Contact.ptype,
                    Contact.pstatus,
                    Contact.pshow,
                    Contact.cached_presence,
                ).where(Contact.id == contact_pk)
            ).first()
            # The last selected column is the "cached" flag.
            if presence is None or not presence[-1]:
                return None
            return CachedPresence(*presence[:-1])

    def set_presence(self, contact_pk: int, presence: CachedPresence) -> None:
        """Store ``presence`` and flag it as cached."""
        self._update_columns(contact_pk, **presence._asdict(), cached_presence=True)

    def reset_presence(self, contact_pk: int):
        """Clear the stored presence fields and the cached flag."""
        self._update_columns(
            contact_pk,
            last_seen=None,
            ptype=None,
            pstatus=None,
            pshow=None,
            cached_presence=False,
        )

    def set_avatar(
        self, contact_pk: int, avatar_pk: Optional[int], avatar_legacy_id: Optional[str]
    ):
        """Link a contact to an avatar row (or unlink it, with None)."""
        self._update_columns(
            contact_pk, avatar_id=avatar_pk, avatar_legacy_id=avatar_legacy_id
        )

    def get_avatar_legacy_id(self, contact_pk: int) -> Optional[str]:
        """Return the avatar legacy ID, or None if the contact has no avatar."""
        with self.session() as session:
            contact = session.execute(
                select(Contact).where(Contact.id == contact_pk)
            ).scalar()
            if contact is None or contact.avatar is None:
                return None
            return contact.avatar_legacy_id

    def update(self, contact: "LegacyContact", commit=True) -> int:
        """
        Insert or update the row backing ``contact`` and return its primary key.

        A new row is created when ``contact.contact_pk`` is None. With
        ``commit=False`` the changes are only flushed, leaving the commit to
        an enclosing session.
        """
        with self.session() as session:
            if contact.contact_pk is None:
                if contact.cached_presence is not None:
                    presence_kwargs = contact.cached_presence._asdict()
                    presence_kwargs["cached_presence"] = True
                else:
                    presence_kwargs = {}
                row = Contact(
                    jid=contact.jid.bare,
                    legacy_id=str(contact.legacy_id),
                    user_account_id=contact.user_pk,
                    **presence_kwargs,
                )
            else:
                row = (
                    session.query(Contact)
                    .filter(Contact.id == contact.contact_pk)
                    .one()
                )
            row.nick = contact.name
            row.is_friend = contact.is_friend
            row.added_to_roster = contact.added_to_roster
            row.updated = True
            row.extra_attributes = contact.serialize_extra_attributes()
            row.caps_ver = contact._caps_ver
            row.vcard = contact._vcard
            row.vcard_fetched = contact._vcard_fetched
            row.client_type = contact.client_type
            session.add(row)
            if commit:
                session.commit()
            else:
                # Flush so that a freshly created row gets its primary key;
                # otherwise row.id would still be None here.
                session.flush()
            return row.id

    def set_vcard(self, contact_pk: int, vcard: str | None) -> None:
        """Store a contact's vCard and mark it as fetched."""
        self._update_columns(contact_pk, vcard=vcard, vcard_fetched=True)

    def add_to_sent(self, contact_pk: int, msg_id: str) -> None:
        """Remember that the contact sent ``msg_id`` (ignoring duplicates)."""
        with self.session() as session:
            if (
                session.query(ContactSent.id)
                .where(ContactSent.contact_id == contact_pk)
                .where(ContactSent.msg_id == msg_id)
                .first()
            ) is not None:
                log.warning(
                    "Contact %s has already sent message %s", contact_pk, msg_id
                )
                return
            new = ContactSent(contact_id=contact_pk, msg_id=msg_id)
            session.add(new)
            session.commit()

    def pop_sent_up_to(self, contact_pk: int, msg_id: str) -> list[str]:
        """
        Remove and return this contact's stored message IDs up to ``msg_id``.

        IDs are taken in ascending row-id order, up to and including
        ``msg_id``. If ``msg_id`` is not stored, every ID stored for the
        contact is returned and removed.
        """
        result: list[str] = []
        to_del: list[int] = []
        with self.session() as session:
            for row in session.execute(
                select(ContactSent)
                .where(ContactSent.contact_id == contact_pk)
                .order_by(ContactSent.id)
            ).scalars():
                to_del.append(row.id)
                result.append(row.msg_id)
                if row.msg_id == msg_id:
                    break
            if to_del:
                session.execute(delete(ContactSent).where(ContactSent.id.in_(to_del)))
                # Commit explicitly: otherwise the deletions are rolled back
                # when a non-nested session is closed.
                session.commit()
        return result

    def set_friend(self, contact_pk: int, is_friend: bool) -> None:
        """Set the ``is_friend`` flag of a contact."""
        self._update_columns(contact_pk, is_friend=is_friend)

    def set_added_to_roster(self, contact_pk: int, value: bool) -> None:
        """Set the ``added_to_roster`` flag of a contact."""
        self._update_columns(contact_pk, added_to_roster=value)

    def delete(self, contact_pk: int) -> None:
        """Delete the contact row with primary key ``contact_pk``."""
        with self.session() as session:
            session.execute(delete(Contact).where(Contact.id == contact_pk))
            session.commit()

    def set_client_type(self, contact_pk: int, value: ClientType):
        """Store the client type of a contact."""
        self._update_columns(contact_pk, client_type=value)
470
+
471
+
472
class MAMStore(EngineMixin):
    """Message archive (MAM) storage, scoped per room."""

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)
        # Every message already in the archive is relabelled as BACKFILL.
        with self.session() as session:
            session.execute(
                update(ArchivedMessage).values(source=ArchivedMessageSource.BACKFILL)
            )
            session.commit()

    def nuke_older_than(self, days: int) -> None:
        """Delete archived messages whose timestamp is older than ``days`` days."""
        # get_messages() interprets stored timestamps as UTC, so compare them
        # with the current UTC wall-clock time, not the server's local time.
        cutoff = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(days=days)
        with self.session() as session:
            session.execute(
                delete(ArchivedMessage).where(ArchivedMessage.timestamp < cutoff)
            )
            session.commit()

    def add_message(
        self,
        room_pk: int,
        message: HistoryMessage,
        archive_only: bool,
        legacy_msg_id: str | None,
    ) -> None:
        """
        Archive ``message`` in room ``room_pk``, or update it if already present.

        An existing row is matched by stanza ID first, then by legacy ID.
        ``archive_only`` selects the BACKFILL source instead of LIVE.
        """
        with self.session() as session:
            source = (
                ArchivedMessageSource.BACKFILL
                if archive_only
                else ArchivedMessageSource.LIVE
            )
            existing = session.execute(
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.stanza_id == message.id)
            ).scalar()
            if existing is None and legacy_msg_id is not None:
                existing = session.execute(
                    select(ArchivedMessage)
                    .where(ArchivedMessage.room_id == room_pk)
                    .where(ArchivedMessage.legacy_id == legacy_msg_id)
                ).scalar()
            if existing is not None:
                log.debug("Updating message %s in room %s", message.id, room_pk)
                existing.timestamp = message.when
                existing.stanza = str(message.stanza)
                existing.author_jid = message.stanza.get_from()
                existing.source = source
                existing.legacy_id = legacy_msg_id
                session.add(existing)
                session.commit()
                return
            mam_msg = ArchivedMessage(
                stanza_id=message.id,
                timestamp=message.when,
                stanza=str(message.stanza),
                author_jid=message.stanza.get_from(),
                room_id=room_pk,
                source=source,
                legacy_id=legacy_msg_id,
            )
            session.add(mam_msg)
            session.commit()

    @staticmethod
    def _get_stamp(session: Session, room_pk: int, stanza_id: str) -> datetime:
        # Timestamp of the message `stanza_id` in this room; raises
        # item-not-found if that message is not archived in this room.
        stamp = session.execute(
            select(ArchivedMessage.timestamp)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.stanza_id == stanza_id)
        ).scalar()
        if stamp is None:
            raise XMPPError(
                "item-not-found",
                f"Message {stanza_id} not found",
            )
        return stamp

    def get_messages(
        self,
        room_pk: int,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        before_id: Optional[str] = None,
        after_id: Optional[str] = None,
        ids: Collection[str] = (),
        last_page_n: Optional[int] = None,
        sender: Optional[str] = None,
        flip=False,
    ) -> Iterator[HistoryMessage]:
        """
        Yield archived messages of a room matching the given filters.

        Messages are ordered by timestamp, ascending (descending if ``flip``).
        ``before_id``/``after_id`` must name messages of the same room.
        ``last_page_n`` keeps only the last ``n`` messages in chronological
        order.

        :raises XMPPError: item-not-found if ``before_id``/``after_id`` is
            unknown, or if some of ``ids`` are not found.
        """
        with self.session() as session:
            q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)
            if start_date is not None:
                q = q.where(ArchivedMessage.timestamp >= start_date)
            if end_date is not None:
                q = q.where(ArchivedMessage.timestamp <= end_date)
            if before_id is not None:
                stamp = self._get_stamp(session, room_pk, before_id)
                q = q.where(ArchivedMessage.timestamp < stamp)
            if after_id is not None:
                stamp = self._get_stamp(session, room_pk, after_id)
                q = q.where(ArchivedMessage.timestamp > stamp)
            if ids:
                q = q.filter(ArchivedMessage.stanza_id.in_(ids))
            if sender is not None:
                q = q.where(ArchivedMessage.author_jid == sender)
            if flip:
                q = q.order_by(ArchivedMessage.timestamp.desc())
            else:
                q = q.order_by(ArchivedMessage.timestamp.asc())
            msgs = list(session.execute(q).scalars())
            # Compare against distinct IDs so that a duplicated requested ID
            # does not trigger a spurious error.
            if ids and len(msgs) != len(set(ids)):
                raise XMPPError(
                    "item-not-found",
                    "One of the requested messages IDs could not be found "
                    "with the given constraints.",
                )
            if last_page_n is not None:
                if flip:
                    msgs = msgs[:last_page_n]
                else:
                    msgs = msgs[-last_page_n:]
            for h in msgs:
                yield HistoryMessage(
                    stanza=str(h.stanza), when=h.timestamp.replace(tzinfo=timezone.utc)
                )

    def get_first(self, room_pk: int, with_legacy_id=False) -> ArchivedMessage | None:
        """Oldest archived message of the room (optionally with a legacy ID)."""
        with self.session() as session:
            q = (
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .order_by(ArchivedMessage.timestamp.asc())
            )
            if with_legacy_id:
                q = q.filter(ArchivedMessage.legacy_id.isnot(None))
            return session.execute(q).scalar()

    def get_last(
        self, room_pk: int, source: ArchivedMessageSource | None = None
    ) -> ArchivedMessage | None:
        """Newest archived message of the room, optionally filtered by source."""
        with self.session() as session:
            q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)

            if source is not None:
                q = q.where(ArchivedMessage.source == source)

            return session.execute(
                q.order_by(ArchivedMessage.timestamp.desc())
            ).scalar()

    def get_first_and_last(self, room_pk: int) -> list[MamMetadata]:
        """Return metadata of the oldest and newest messages (0, 1 or 2 items)."""
        r = []
        # One shared session for both lookups.
        with self.session():
            first = self.get_first(room_pk)
            if first is not None:
                r.append(MamMetadata(first.stanza_id, first.timestamp))
            last = self.get_last(room_pk)
            if last is not None:
                r.append(MamMetadata(last.stanza_id, last.timestamp))
        return r

    def get_most_recent_with_legacy_id(
        self, room_pk: int, source: ArchivedMessageSource | None = None
    ) -> ArchivedMessage | None:
        """Newest message of the room that has a legacy ID."""
        with self.session() as session:
            q = (
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.legacy_id.isnot(None))
            )
            if source is not None:
                q = q.where(ArchivedMessage.source == source)
            return session.execute(
                q.order_by(ArchivedMessage.timestamp.desc())
            ).scalar()

    def get_least_recent_with_legacy_id_after(
        self, room_pk: int, after_id: str, source=ArchivedMessageSource.LIVE
    ) -> ArchivedMessage | None:
        """
        Oldest message with a legacy ID newer than the message ``after_id``.

        ``after_id`` is a legacy ID; returns None if it is not archived in
        this room.
        """
        with self.session() as session:
            after_timestamp = (
                session.query(ArchivedMessage.timestamp)
                .filter(ArchivedMessage.room_id == room_pk)
                .filter(ArchivedMessage.legacy_id == after_id)
                .scalar()
            )
            if after_timestamp is None:
                # Comparing with NULL would match nothing anyway.
                return None
            q = (
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.legacy_id.isnot(None))
                .where(ArchivedMessage.source == source)
                .where(ArchivedMessage.timestamp > after_timestamp)
            )
            return session.execute(q.order_by(ArchivedMessage.timestamp.asc())).scalar()

    def get_by_legacy_id(self, room_pk: int, legacy_id: str) -> ArchivedMessage | None:
        """Archived message of the room with this legacy ID, or None."""
        with self.session() as session:
            return (
                session.query(ArchivedMessage)
                .filter(ArchivedMessage.room_id == room_pk)
                .filter(ArchivedMessage.legacy_id == legacy_id)
                .first()
            )
681
+
682
+
683
class MultiStore(EngineMixin):
    """Mapping of one legacy message ID to several XMPP message IDs."""

    def get_xmpp_ids(self, user_pk: int, xmpp_id: str) -> list[str]:
        """
        Return every XMPP ID grouped with ``xmpp_id`` under the same legacy ID.

        Returns an empty list if ``xmpp_id`` is unknown for this user.
        """
        with self.session() as session:
            multi = session.execute(
                select(XmppIdsMulti)
                .where(XmppIdsMulti.xmpp_id == xmpp_id)
                .where(XmppIdsMulti.user_account_id == user_pk)
            ).scalar()
            if multi is None:
                return []
            return [m.xmpp_id for m in multi.legacy_ids_multi.xmpp_ids]

    def set_xmpp_ids(
        self, user_pk: int, legacy_msg_id: str, xmpp_ids: list[str], fail=False
    ) -> None:
        """
        Associate ``legacy_msg_id`` with the XMPP IDs ``xmpp_ids``.

        Any existing association for ``legacy_msg_id`` is replaced. When
        replacing, rows for the IDs in ``xmpp_ids`` are deleted first. Empty
        strings in ``xmpp_ids`` are not stored.

        :param fail: if True, raise instead of replacing an existing
            association.
        :raises RuntimeError: if ``fail`` is True and ``legacy_msg_id`` is
            already associated.
        """
        with self.session() as session:
            existing = session.execute(
                select(LegacyIdsMulti)
                .where(LegacyIdsMulti.user_account_id == user_pk)
                .where(LegacyIdsMulti.legacy_id == legacy_msg_id)
            ).scalar()
            if existing is not None:
                if fail:
                    raise RuntimeError(
                        f"XMPP IDs already set for legacy message {legacy_msg_id}"
                    )
                log.debug("Resetting multi for %s", legacy_msg_id)
                session.execute(
                    delete(LegacyIdsMulti)
                    .where(LegacyIdsMulti.user_account_id == user_pk)
                    .where(LegacyIdsMulti.legacy_id == legacy_msg_id)
                )
                for i in xmpp_ids:
                    session.execute(
                        delete(XmppIdsMulti)
                        .where(XmppIdsMulti.user_account_id == user_pk)
                        .where(XmppIdsMulti.xmpp_id == i)
                    )
                # Commit the removal before inserting the replacement rows.
                session.commit()

            row = LegacyIdsMulti(
                user_account_id=user_pk,
                legacy_id=legacy_msg_id,
                xmpp_ids=[
                    XmppIdsMulti(user_account_id=user_pk, xmpp_id=i)
                    for i in xmpp_ids
                    if i
                ],
            )
            session.add(row)
            session.commit()

    def get_legacy_id(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        """Return the legacy ID that ``xmpp_id`` is grouped under, or None."""
        with self.session() as session:
            multi = session.execute(
                select(XmppIdsMulti)
                .where(XmppIdsMulti.xmpp_id == xmpp_id)
                .where(XmppIdsMulti.user_account_id == user_pk)
            ).scalar()
            if multi is None:
                return None
            return multi.legacy_ids_multi.legacy_id
745
+
746
+
747
class AttachmentStore(EngineMixin):
    """
    Stored attachments: legacy file ID -> URL.

    Also holds the ``sims`` and ``sfs`` values kept per URL.
    """

    def _column_for_url(self, column, url: str) -> Optional[str]:
        # Read one column of the attachment row(s) matching `url`.
        with self.session() as session:
            return session.execute(select(column).where(Attachment.url == url)).scalar()

    def _set_for_url(self, url: str, **values) -> None:
        # Write the given columns on the attachment row(s) matching `url`.
        with self.session() as session:
            session.execute(
                update(Attachment).where(Attachment.url == url).values(**values)
            )
            session.commit()

    def get_url(self, legacy_file_id: str) -> Optional[str]:
        """URL previously stored for ``legacy_file_id``, or None."""
        lookup = select(Attachment.url).where(
            Attachment.legacy_file_id == legacy_file_id
        )
        with self.session() as session:
            return session.execute(lookup).scalar()

    def set_url(self, user_pk: int, legacy_file_id: str, url: str) -> None:
        """Create or overwrite the URL of ``legacy_file_id`` for user ``user_pk``."""
        with self.session() as session:
            att = session.execute(
                select(Attachment)
                .where(Attachment.legacy_file_id == legacy_file_id)
                .where(Attachment.user_account_id == user_pk)
            ).scalar()
            if att is not None:
                att.url = url
            else:
                session.add(
                    Attachment(
                        legacy_file_id=legacy_file_id, url=url, user_account_id=user_pk
                    )
                )
            session.commit()

    def get_sims(self, url: str) -> Optional[str]:
        """Stored ``sims`` value for ``url``, or None."""
        return self._column_for_url(Attachment.sims, url)

    def set_sims(self, url: str, sims: str) -> None:
        """Store the ``sims`` value for ``url``."""
        self._set_for_url(url, sims=sims)

    def get_sfs(self, url: str) -> Optional[str]:
        """Stored ``sfs`` value for ``url``, or None."""
        return self._column_for_url(Attachment.sfs, url)

    def set_sfs(self, url: str, sfs: str) -> None:
        """Store the ``sfs`` value for ``url``."""
        self._set_for_url(url, sfs=sfs)

    def remove(self, legacy_file_id: str) -> None:
        """Delete every attachment row with this legacy file ID."""
        removal = delete(Attachment).where(Attachment.legacy_file_id == legacy_file_id)
        with self.session() as session:
            session.execute(removal)
            session.commit()
804
+
805
+
806
class RoomStore(UpdatedMixin):
    """
    Persistence helpers for :class:`Room` rows (the group chats of each user).

    Every method opens its own short-lived session; writers commit before
    returning.
    """

    model = Room

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)
        # Reset these columns for every room when the store is created.
        # NOTE(review): presumably they only describe the state of a previous
        # run (subject setter, joined resources, fetched history/participants).
        with self.session() as session:
            session.execute(
                update(Room).values(
                    subject_setter=None,
                    user_resources=None,
                    history_filled=False,
                    participants_filled=False,
                )
            )
            session.commit()

    def set_avatar(
        self, room_pk: int, avatar_pk: int | None, avatar_legacy_id: str | None
    ) -> None:
        """
        Set (or clear, by passing ``None``) the avatar of room ``room_pk``.

        :param room_pk: primary key of the room
        :param avatar_pk: value stored in ``Room.avatar_id``
        :param avatar_legacy_id: value stored in ``Room.avatar_legacy_id``
        """
        with self.session() as session:
            session.execute(
                update(Room)
                .where(Room.id == room_pk)
                .values(avatar_id=avatar_pk, avatar_legacy_id=avatar_legacy_id)
            )
            session.commit()

    def get_avatar_legacy_id(self, room_pk: int) -> Optional[str]:
        """
        Return ``avatar_legacy_id`` of room ``room_pk``.

        Returns ``None`` when the room does not exist or its ``avatar``
        relationship is empty.
        """
        with self.session() as session:
            room = session.execute(select(Room).where(Room.id == room_pk)).scalar()
            if room is None or room.avatar is None:
                return None
            return room.avatar_legacy_id

    def get_by_jid(self, user_pk: int, jid: JID) -> Optional[Room]:
        """
        Return the room of user ``user_pk`` whose JID is ``jid``, if any.

        :raises TypeError: if ``jid`` has a resource part.
        """
        if jid.resource:
            raise TypeError
        with self.session() as session:
            return session.execute(
                select(Room)
                .where(Room.user_account_id == user_pk)
                .where(Room.jid == jid)
            ).scalar()

    def get_by_legacy_id(self, user_pk: int, legacy_id: str) -> Optional[Room]:
        """Return the room of user ``user_pk`` with this ``legacy_id``, if any."""
        with self.session() as session:
            return session.execute(
                select(Room)
                .where(Room.user_account_id == user_pk)
                .where(Room.legacy_id == legacy_id)
            ).scalar()

    def update_subject_setter(self, room_pk: int, subject_setter: str | None):
        """Store ``subject_setter`` for room ``room_pk`` (``None`` clears it)."""
        with self.session() as session:
            session.execute(
                update(Room)
                .where(Room.id == room_pk)
                .values(subject_setter=subject_setter)
            )
            session.commit()

    def update(self, muc: "LegacyMUC") -> int:
        """
        Insert or overwrite the row mirroring ``muc`` and return its primary key.

        A new row is created when ``muc.pk`` is ``None``; otherwise the row with
        primary key ``muc.pk`` is loaded (``.one()`` raises if it is missing)
        and its columns are overwritten. The row is flagged ``updated``.
        """
        with self.session() as session:
            if muc.pk is None:
                row = Room(
                    jid=muc.jid,
                    legacy_id=str(muc.legacy_id),
                    user_account_id=muc.user_pk,
                )
            else:
                row = session.query(Room).filter(Room.id == muc.pk).one()

            row.updated = True
            row.extra_attributes = muc.serialize_extra_attributes()
            row.name = muc.name
            row.description = muc.description
            # Stored as a JSON list, or NULL when there are no resources.
            row.user_resources = (
                None
                if not muc._user_resources
                else json.dumps(list(muc._user_resources))
            )
            row.muc_type = muc.type
            row.subject = muc.subject
            row.subject_date = muc.subject_date
            row.subject_setter = muc.subject_setter
            row.participants_filled = muc._participants_filled
            row.n_participants = muc._n_participants
            row.user_nick = muc.user_nick
            session.add(row)
            session.commit()
            return row.id

    def update_subject_date(
        self, room_pk: int, subject_date: Optional[datetime]
    ) -> None:
        """Store ``subject_date`` for room ``room_pk``."""
        with self.session() as session:
            session.execute(
                update(Room).where(Room.id == room_pk).values(subject_date=subject_date)
            )
            session.commit()

    def update_subject(self, room_pk: int, subject: Optional[str]) -> None:
        """Store ``subject`` for room ``room_pk``."""
        with self.session() as session:
            session.execute(
                update(Room).where(Room.id == room_pk).values(subject=subject)
            )
            session.commit()

    def update_description(self, room_pk: int, desc: Optional[str]) -> None:
        """Store ``desc`` as the description of room ``room_pk``."""
        with self.session() as session:
            session.execute(
                update(Room).where(Room.id == room_pk).values(description=desc)
            )
            session.commit()

    def update_name(self, room_pk: int, name: Optional[str]) -> None:
        """Store ``name`` for room ``room_pk``."""
        with self.session() as session:
            session.execute(update(Room).where(Room.id == room_pk).values(name=name))
            session.commit()

    def update_n_participants(self, room_pk: int, n: Optional[int]) -> None:
        """Store ``n`` as the participant count of room ``room_pk``."""
        with self.session() as session:
            session.execute(
                update(Room).where(Room.id == room_pk).values(n_participants=n)
            )
            session.commit()

    def update_user_nick(self, room_pk, nick: str) -> None:
        """Store ``nick`` as the user's nickname in room ``room_pk``."""
        with self.session() as session:
            session.execute(
                update(Room).where(Room.id == room_pk).values(user_nick=nick)
            )
            session.commit()

    def delete(self, room_pk: int) -> None:
        """Delete room ``room_pk`` together with all its participant rows."""
        with self.session() as session:
            # Delete the child rows first. NOTE(review): Participant.room_id
            # presumably references Room.id; this order stays valid even if
            # foreign keys are enforced without ON DELETE CASCADE.
            session.execute(delete(Participant).where(Participant.room_id == room_pk))
            session.execute(delete(Room).where(Room.id == room_pk))
            session.commit()

    def set_resource(self, room_pk: int, resources: set[str]) -> None:
        """
        Store ``resources`` in ``user_resources`` of room ``room_pk``.

        The set is serialised as a JSON list; an empty set is stored as NULL.
        """
        with self.session() as session:
            session.execute(
                update(Room)
                .where(Room.id == room_pk)
                .values(
                    user_resources=(
                        None if not resources else json.dumps(list(resources))
                    )
                )
            )
            session.commit()

    def nickname_is_available(self, room_pk: int, nickname: str) -> bool:
        """Return True if no participant of room ``room_pk`` uses ``nickname``."""
        with self.session() as session:
            return (
                session.execute(
                    select(Participant)
                    .where(Participant.room_id == room_pk)
                    .where(Participant.nickname == nickname)
                ).scalar()
                is None
            )

    def set_participants_filled(self, room_pk: int, val=True) -> None:
        """Set the ``participants_filled`` flag of room ``room_pk`` to ``val``."""
        with self.session() as session:
            session.execute(
                update(Room).where(Room.id == room_pk).values(participants_filled=val)
            )
            session.commit()

    def set_history_filled(self, room_pk: int, val=True) -> None:
        """Set the ``history_filled`` flag of room ``room_pk`` to ``val``."""
        with self.session() as session:
            # Use ``val``: the flag previously could never be reset to False.
            session.execute(
                update(Room).where(Room.id == room_pk).values(history_filled=val)
            )
            session.commit()

    def get_all(self, user_pk: int) -> Iterator[Room]:
        """
        Yield every room of user ``user_pk``.

        The session stays open while the generator is being consumed.
        """
        with self.session() as session:
            yield from session.execute(
                select(Room).where(Room.user_account_id == user_pk)
            ).scalars()

    def get_all_jid_and_names(self, user_pk: int) -> Iterator[Room]:
        """
        Yield the rooms of user ``user_pk`` ordered by name.

        Only the ``jid`` and ``name`` columns are loaded up front
        (``load_only``).
        """
        with self.session() as session:
            yield from session.scalars(
                select(Room)
                .filter(Room.user_account_id == user_pk)
                .options(load_only(Room.jid, Room.name))
                .order_by(Room.name)
            ).all()
998
+
999
+
1000
class ParticipantStore(EngineMixin):
    """Persistence helpers for MUC :class:`Participant` rows and their hats."""

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)
        # Start from empty participant/hat tables whenever the store is created.
        # The association table is emptied first since it presumably references
        # both ``Hat`` and ``Participant``.
        with self.session() as session:
            session.execute(delete(participant_hats))
            session.execute(delete(Hat))
            session.execute(delete(Participant))
            session.commit()

    def add(self, room_pk: int, nickname: str) -> int:
        """
        Return the primary key of participant ``nickname`` in room ``room_pk``.

        The row is created if it does not exist yet.
        """
        with self.session() as session:
            existing = session.execute(
                select(Participant.id)
                .where(Participant.room_id == room_pk)
                .where(Participant.nickname == nickname)
            ).scalar()
            if existing is not None:
                return existing
            participant = Participant(room_id=room_pk, nickname=nickname)
            session.add(participant)
            session.commit()
            return participant.id

    def get_by_nickname(self, room_pk: int, nickname: str) -> Optional[Participant]:
        """Return the participant of room ``room_pk`` named ``nickname``, if any."""
        with self.session() as session:
            return session.execute(
                select(Participant)
                .where(Participant.room_id == room_pk)
                .where(Participant.nickname == nickname)
            ).scalar()

    def get_by_resource(self, room_pk: int, resource: str) -> Optional[Participant]:
        """Return the participant of room ``room_pk`` with ``resource``, if any."""
        with self.session() as session:
            return session.execute(
                select(Participant)
                .where(Participant.room_id == room_pk)
                .where(Participant.resource == resource)
            ).scalar()

    def get_by_contact(self, room_pk: int, contact_pk: int) -> Optional[Participant]:
        """Return the participant of room ``room_pk`` linked to ``contact_pk``."""
        with self.session() as session:
            return session.execute(
                select(Participant)
                .where(Participant.room_id == room_pk)
                .where(Participant.contact_id == contact_pk)
            ).scalar()

    def get_all(self, room_pk: int, user_included=True) -> Iterator[Participant]:
        """
        Yield the participants of room ``room_pk``.

        :param user_included: when False, rows with ``is_user`` set are skipped.
        """
        with self.session() as session:
            q = select(Participant).where(Participant.room_id == room_pk)
            if not user_included:
                q = q.where(~Participant.is_user)
            yield from session.execute(q).scalars()

    def get_for_contact(self, contact_pk: int) -> Iterator[Participant]:
        """Yield every participant row (in any room) linked to ``contact_pk``."""
        with self.session() as session:
            yield from session.execute(
                select(Participant).where(Participant.contact_id == contact_pk)
            ).scalars()

    def update(self, participant: "LegacyParticipant") -> None:
        """Copy the persisted attributes of ``participant`` into its row."""
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant.pk)
                .values(
                    resource=participant.jid.resource,
                    affiliation=participant.affiliation,
                    role=participant.role,
                    presence_sent=participant._presence_sent,  # type:ignore
                    # hats=[self.add_hat(h.uri, h.title) for h in participant._hats],
                    is_user=participant.is_user,
                    contact_id=(
                        None
                        if participant.contact is None
                        else participant.contact.contact_pk
                    ),
                )
            )
            session.commit()

    @staticmethod
    def _get_or_create_hat(session, uri: str, title: str) -> Hat:
        """Return the ``Hat`` (uri, title) bound to ``session``, creating it if missing."""
        hat = session.execute(
            select(Hat).where(Hat.uri == uri).where(Hat.title == title)
        ).scalar()
        if hat is None:
            hat = Hat(uri=uri, title=title)
            session.add(hat)
            # Flush so that a later lookup of the same (uri, title) in this
            # session returns this very object instead of creating a duplicate.
            session.flush()
        return hat

    def add_hat(self, uri: str, title: str) -> Hat:
        """Return the ``Hat`` row for (``uri``, ``title``), creating it if needed."""
        with self.session() as session:
            existing = session.execute(
                select(Hat).where(Hat.uri == uri).where(Hat.title == title)
            ).scalar()
            if existing is not None:
                return existing
            hat = Hat(uri=uri, title=title)
            session.add(hat)
            session.commit()
            return hat

    def set_presence_sent(self, participant_pk: int) -> None:
        """Set ``presence_sent`` to True for participant ``participant_pk``."""
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant_pk)
                .values(presence_sent=True)
            )
            session.commit()

    def set_affiliation(self, participant_pk: int, affiliation: MucAffiliation) -> None:
        """Store ``affiliation`` for participant ``participant_pk``."""
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant_pk)
                .values(affiliation=affiliation)
            )
            session.commit()

    def set_role(self, participant_pk: int, role: MucRole) -> None:
        """Store ``role`` for participant ``participant_pk``."""
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant_pk)
                .values(role=role)
            )
            session.commit()

    def set_hats(self, participant_pk: int, hats: list[HatTuple]) -> None:
        """
        Replace the hats of participant ``participant_pk`` with ``hats``.

        Each element of ``hats`` is unpacked as ``(uri, title)``; duplicates are
        ignored.

        :raises ValueError: if the participant does not exist.
        """
        with self.session() as session:
            part = session.execute(
                select(Participant).where(Participant.id == participant_pk)
            ).scalar()
            if part is None:
                raise ValueError
            part.hats.clear()
            for h in hats:
                # Resolve hats in *this* session: objects from another session
                # would defeat the identity-based membership test below.
                hat = self._get_or_create_hat(session, *h)
                if hat in part.hats:
                    continue
                part.hats.append(hat)
            session.commit()

    def delete(self, participant_pk: int) -> None:
        """Delete participant ``participant_pk``."""
        with self.session() as session:
            session.execute(delete(Participant).where(Participant.id == participant_pk))
            # Without an explicit commit the DELETE is discarded on close.
            session.commit()

    def get_count(self, room_pk: int) -> int:
        """Return the number of participant rows of room ``room_pk``."""
        with self.session() as session:
            return session.scalar(
                select(count(Participant.id)).where(Participant.room_id == room_pk)
            )
1144
+
1145
+
1146
class BobStore(EngineMixin):
    """
    Storage for Bits of Binary (XEP-0231) payloads.

    Payload bytes are written to files under ``config.HOME_DIR /
    "slidge_stickers"``. The ``Bob`` table keeps their file name, content type
    and SHA-1/256/512 digests. A CID has the form
    ``<alg>+<hex digest>@bob.xmpp.org``.
    """

    # CID algorithm name -> ``Bob`` column holding that digest
    _ATTR_MAP = {
        "sha-1": "sha_1",
        "sha1": "sha_1",
        "sha-256": "sha_256",
        "sha256": "sha_256",
        "sha-512": "sha_512",
        "sha512": "sha_512",
    }

    # ``Bob`` digest column -> hash constructor used to compute it
    _ALG_MAP = {
        "sha_1": hashlib.sha1,
        "sha_256": hashlib.sha256,
        "sha_512": hashlib.sha512,
    }

    def __init__(self, *a, **k):
        super().__init__(*a, **k)
        self.root_dir = config.HOME_DIR / "slidge_stickers"
        self.root_dir.mkdir(exist_ok=True)

    @staticmethod
    def __split_cid(cid: str) -> list[str]:
        """Split a CID into ``[alg, digest]`` (any other length is malformed)."""
        return cid.removesuffix("@bob.xmpp.org").split("+")

    def __get_condition(self, cid: str):
        """
        Build the SQL condition matching ``cid``'s digest column.

        Returns ``None`` (and logs) for an unknown hash algorithm.

        :raises ValueError: if ``cid`` does not split into exactly two parts.
        """
        alg_name, digest = self.__split_cid(cid)
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Unknown hash algo: %s", alg_name)
            return None
        return getattr(Bob, attr) == digest

    def get(self, cid: str) -> Bob | None:
        """Return the ``Bob`` row matching ``cid``, or ``None``."""
        try:
            condition = self.__get_condition(cid)
        except ValueError:
            log.warning("Cannot get Bob with CID: %s", cid)
            return None
        if condition is None:
            # Unknown algorithm (already logged): nothing can match.
            return None
        with self.session() as session:
            return session.query(Bob).filter(condition).scalar()

    def get_sticker(self, cid: str) -> Sticker | None:
        """Return a ``Sticker`` (path, content type, digests) for ``cid``."""
        bob = self.get(cid)
        if bob is None:
            return None
        return Sticker(
            self.root_dir / bob.file_name,
            bob.content_type,
            {h: getattr(bob, h) for h in self._ALG_MAP},
        )

    def get_bob(self, _jid, _node, _ifrom, cid: str) -> BitsOfBinary | None:
        """
        Build a ``BitsOfBinary`` stanza payload for ``cid`` from disk.

        NOTE(review): the unused leading parameters presumably match a
        slixmpp handler signature; confirm before changing them.
        """
        stored = self.get(cid)
        if stored is None:
            return None
        bob = BitsOfBinary()
        bob["data"] = (self.root_dir / stored.file_name).read_bytes()
        if stored.content_type is not None:
            bob["type"] = stored.content_type
        bob["cid"] = cid
        return bob

    def del_bob(self, _jid, _node, _ifrom, cid: str) -> None:
        """Delete the ``Bob`` row matching ``cid`` and its file on disk."""
        try:
            condition = self.__get_condition(cid)
        except ValueError:
            log.warning("Cannot delete Bob with CID: %s", cid)
            return None
        if condition is None:
            # Unknown algorithm (already logged): nothing to delete.
            return None
        with self.session() as orm:
            file_name = orm.scalar(
                delete(Bob).where(condition).returning(Bob.file_name)
            )
            if file_name is None:
                log.warning("No BoB with CID: %s", cid)
                return None
            orm.commit()
        # Remove the file only once the row is gone. A file that has already
        # disappeared must not abort the database deletion.
        (self.root_dir / file_name).unlink(missing_ok=True)

    def set_bob(self, _jid, _node, _ifrom, bob: BitsOfBinary) -> None:
        """
        Persist ``bob`` unless its CID is malformed, unknown or already stored.

        :raises ValueError: if the data's digest does not match the CID.
        """
        cid = bob["cid"]
        try:
            alg_name, digest = self.__split_cid(cid)
        except ValueError:
            log.warning("Cannot set Bob with CID: %s", cid)
            return
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Cannot set BoB with unknown hash algo: %s", alg_name)
            return None
        if self.get(cid) is not None:
            log.debug("Bob already known")
            return
        bytes_ = bob["data"]
        hashes = {k: v(bytes_).hexdigest() for k, v in self._ALG_MAP.items()}
        # Verify before touching the disk so a bad CID leaves no orphan file.
        if hashes[attr] != digest:
            raise ValueError(
                "The given CID does not correspond to the result of our hash"
            )
        path = self.root_dir / uuid.uuid4().hex
        if bob["type"]:
            path = path.with_suffix(guess_extension(bob["type"]) or "")
        path.write_bytes(bytes_)
        with self.session() as orm:
            orm.add(
                Bob(file_name=path.name, content_type=bob["type"] or None, **hashes)
            )
            orm.commit()
1254
+
1255
+
1256
# Module-level logger; BobStore's warnings above are emitted through it.
log = logging.getLogger(__name__)
# NOTE(review): module-level Session holder, initialised to None here;
# confirm where it gets assigned.
_session: Optional[Session] = None