slidge 0.1.2__py3-none-any.whl → 0.2.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. slidge/__init__.py +3 -5
  2. slidge/__main__.py +2 -196
  3. slidge/__version__.py +5 -0
  4. slidge/command/adhoc.py +8 -1
  5. slidge/command/admin.py +5 -6
  6. slidge/command/base.py +1 -2
  7. slidge/command/register.py +32 -16
  8. slidge/command/user.py +85 -5
  9. slidge/contact/contact.py +93 -31
  10. slidge/contact/roster.py +54 -39
  11. slidge/core/config.py +13 -7
  12. slidge/core/gateway/base.py +139 -34
  13. slidge/core/gateway/disco.py +2 -4
  14. slidge/core/gateway/mam.py +1 -4
  15. slidge/core/gateway/ping.py +2 -3
  16. slidge/core/gateway/presence.py +1 -1
  17. slidge/core/gateway/registration.py +32 -21
  18. slidge/core/gateway/search.py +3 -5
  19. slidge/core/gateway/session_dispatcher.py +109 -51
  20. slidge/core/gateway/vcard_temp.py +6 -4
  21. slidge/core/mixins/__init__.py +11 -1
  22. slidge/core/mixins/attachment.py +15 -10
  23. slidge/core/mixins/avatar.py +66 -18
  24. slidge/core/mixins/base.py +8 -2
  25. slidge/core/mixins/message.py +11 -7
  26. slidge/core/mixins/message_maker.py +17 -9
  27. slidge/core/mixins/presence.py +14 -4
  28. slidge/core/pubsub.py +54 -212
  29. slidge/core/session.py +65 -33
  30. slidge/db/__init__.py +4 -0
  31. slidge/db/alembic/env.py +64 -0
  32. slidge/db/alembic/script.py.mako +26 -0
  33. slidge/db/alembic/versions/09f27f098baa_add_missing_attributes_in_room.py +36 -0
  34. slidge/db/alembic/versions/29f5280c61aa_store_subject_setter_in_room.py +37 -0
  35. slidge/db/alembic/versions/8d2ced764698_rely_on_db_to_store_contacts_rooms_and_.py +133 -0
  36. slidge/db/alembic/versions/aa9d82a7f6ef_db_creation.py +76 -0
  37. slidge/db/alembic/versions/b33993e87db3_move_everything_to_persistent_db.py +214 -0
  38. slidge/db/alembic/versions/e91195719c2c_store_users_avatars_persistently.py +26 -0
  39. slidge/db/avatar.py +224 -0
  40. slidge/db/meta.py +65 -0
  41. slidge/db/models.py +365 -0
  42. slidge/db/store.py +976 -0
  43. slidge/group/archive.py +13 -14
  44. slidge/group/bookmarks.py +59 -56
  45. slidge/group/participant.py +81 -29
  46. slidge/group/room.py +242 -142
  47. slidge/main.py +201 -0
  48. slidge/migration.py +30 -0
  49. slidge/slixfix/__init__.py +35 -2
  50. slidge/slixfix/roster.py +11 -4
  51. slidge/slixfix/xep_0292/vcard4.py +1 -0
  52. slidge/util/db.py +1 -47
  53. slidge/util/test.py +21 -4
  54. slidge/util/types.py +24 -4
  55. {slidge-0.1.2.dist-info → slidge-0.2.0a0.dist-info}/METADATA +3 -1
  56. slidge-0.2.0a0.dist-info/RECORD +108 -0
  57. slidge/core/cache.py +0 -183
  58. slidge/util/schema.sql +0 -126
  59. slidge/util/sql.py +0 -508
  60. slidge-0.1.2.dist-info/RECORD +0 -96
  61. {slidge-0.1.2.dist-info → slidge-0.2.0a0.dist-info}/LICENSE +0 -0
  62. {slidge-0.1.2.dist-info → slidge-0.2.0a0.dist-info}/WHEEL +0 -0
  63. {slidge-0.1.2.dist-info → slidge-0.2.0a0.dist-info}/entry_points.txt +0 -0
slidge/db/store.py ADDED
@@ -0,0 +1,976 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import logging
5
+ import warnings
6
+ from contextlib import contextmanager
7
+ from datetime import datetime, timedelta, timezone
8
+ from typing import TYPE_CHECKING, Collection, Iterator, Optional, Type
9
+
10
+ from slixmpp import JID, Iq, Message, Presence
11
+ from slixmpp.exceptions import XMPPError
12
+ from sqlalchemy import Engine, delete, select, update
13
+ from sqlalchemy.orm import Session, attributes
14
+
15
+ from ..util.archive_msg import HistoryMessage
16
+ from ..util.types import URL, CachedPresence
17
+ from ..util.types import Hat as HatTuple
18
+ from ..util.types import MamMetadata, MucAffiliation, MucRole
19
+ from .meta import Base
20
+ from .models import (
21
+ ArchivedMessage,
22
+ Attachment,
23
+ Avatar,
24
+ Contact,
25
+ ContactSent,
26
+ GatewayUser,
27
+ Hat,
28
+ LegacyIdsMulti,
29
+ Participant,
30
+ Room,
31
+ XmppIdsMulti,
32
+ XmppToLegacyEnum,
33
+ XmppToLegacyIds,
34
+ )
35
+
36
+ if TYPE_CHECKING:
37
+ from ..contact.contact import LegacyContact
38
+ from ..group.participant import LegacyParticipant
39
+ from ..group.room import LegacyMUC
40
+
41
+
42
class EngineMixin:
    """Base class giving a store access to a (shared) SQLAlchemy session."""

    def __init__(self, engine: Engine):
        self._engine = engine

    @contextmanager
    def session(self, **session_kwargs) -> Iterator[Session]:
        """
        Yield a session bound to this store's engine.

        When a session opened through this method is already active (it is
        tracked in the module-level ``_session``), that same session is
        yielded and ``session_kwargs`` are not used. Otherwise a new
        ``Session`` is created with ``session_kwargs``, registered as the
        active one while the ``with`` body runs, and unregistered afterwards.
        """
        global _session
        if _session is None:
            with Session(self._engine, **session_kwargs) as fresh:
                _session = fresh
                try:
                    yield fresh
                finally:
                    # Always unregister, even if the caller's block raised.
                    _session = None
        else:
            # Nested call: reuse the session of the outermost caller.
            yield _session
58
+
59
+
60
class UpdatedMixin(EngineMixin):
    """Store base that clears the ``updated`` flag of all ``model`` rows on creation."""

    # Subclasses set this to the mapped class whose rows carry ``updated``.
    model: Type[Base] = NotImplemented

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        with self.session() as db:
            reset_all = update(self.model).values(updated=False)  # type:ignore
            db.execute(reset_all)
            db.commit()
68
+
69
+
70
class SlidgeStore(EngineMixin):
    """Aggregate of all the specialised stores, each bound to the same engine."""

    def __init__(self, engine: Engine):
        super().__init__(engine)
        # Creation order is kept as-is: several of these stores run
        # statements against the database in their own ``__init__``.
        self.users = UserStore(engine)
        self.avatars = AvatarStore(engine)
        self.contacts = ContactStore(engine)
        self.mam = MAMStore(engine)
        self.multi = MultiStore(engine)
        self.attachments = AttachmentStore(engine)
        self.rooms = RoomStore(engine)
        self.sent = SentStore(engine)
        self.participants = ParticipantStore(engine)
82
+
83
+
84
class UserStore(EngineMixin):
    """Persistence of ``GatewayUser`` rows."""

    def new(self, jid: JID, legacy_module_data: dict) -> GatewayUser:
        """
        Return the user registered under ``jid`` (resource stripped),
        creating and committing a new row when none exists yet.
        """
        bare_jid = JID(jid.bare) if jid.resource else jid
        with self.session(expire_on_commit=False) as db:
            lookup = select(GatewayUser).where(GatewayUser.jid == bare_jid)
            found = db.execute(lookup).scalar()
            if found is None:
                found = GatewayUser(
                    jid=bare_jid, legacy_module_data=legacy_module_data
                )
                db.add(found)
                db.commit()
            return found

    def update(self, user: GatewayUser):
        """Persist ``user``, forcing its JSON-like attributes to be written."""
        # In-place changes to these attributes are not detected on their own:
        # https://github.com/sqlalchemy/sqlalchemy/discussions/6473
        for attr_name in ("legacy_module_data", "preferences"):
            attributes.flag_modified(user, attr_name)
        with self.session() as db:
            db.add(user)
            db.commit()

    def get_all(self) -> Iterator[GatewayUser]:
        """Iterate over every stored user."""
        with self.session() as db:
            for user in db.execute(select(GatewayUser)).scalars():
                yield user

    def get(self, jid: JID) -> Optional[GatewayUser]:
        """Return the user whose JID equals ``jid.bare``, or ``None``."""
        by_bare_jid = select(GatewayUser).where(GatewayUser.jid == jid.bare)
        with self.session() as db:
            return db.execute(by_bare_jid).scalar()

    def get_by_stanza(self, stanza: Iq | Message | Presence) -> Optional[GatewayUser]:
        """Return the user matching the sender (``from``) of ``stanza``, or ``None``."""
        sender = stanza.get_from()
        return self.get(sender)

    def delete(self, jid: JID) -> None:
        """Delete the user registered under ``jid`` and commit."""
        with self.session() as db:
            # The nested ``get`` call reuses this very session, so the
            # instance it returns is attached to ``db``.
            user = self.get(jid)
            db.delete(user)
            db.commit()
124
+
125
+
126
class AvatarStore(EngineMixin):
    """Lookup of ``Avatar`` rows and their association to contacts and rooms."""

    def get_by_url(self, url: URL | str) -> Optional[Avatar]:
        """Return the avatar stored with this ``url``, or ``None``."""
        by_url = select(Avatar).where(Avatar.url == url)
        with self.session() as db:
            return db.execute(by_url).scalar()

    def get_by_legacy_id(self, legacy_id: str) -> Optional[Avatar]:
        """Return the avatar stored with this ``legacy_id``, or ``None``."""
        by_legacy_id = select(Avatar).where(Avatar.legacy_id == legacy_id)
        with self.session() as db:
            return db.execute(by_legacy_id).scalar()

    def store_jid(self, sha: str, jid: JID) -> None:
        """
        Point every contact and room whose JID is ``jid.bare`` at the avatar
        whose hash is ``sha``. Emits a warning and does nothing when no such
        avatar exists.
        """
        with self.session() as db:
            avatar_pk = db.execute(
                select(Avatar.id).where(Avatar.hash == sha)
            ).scalar()
            if avatar_pk is None:
                warnings.warn("avatar not found")
                return
            bare = jid.bare
            db.execute(
                update(Contact).where(Contact.jid == bare).values(avatar_id=avatar_pk)
            )
            db.execute(
                update(Room).where(Room.jid == bare).values(avatar_id=avatar_pk)
            )
            db.commit()

    def get_by_jid(self, jid: JID) -> Optional[Avatar]:
        """
        Return the avatar attached to a contact with this ``jid`` or,
        failing that, to a room with this ``jid``; ``None`` if neither.
        """
        via_contact = select(Avatar).where(Avatar.contacts.any(Contact.jid == jid))
        via_room = select(Avatar).where(Avatar.rooms.any(Room.jid == jid))
        with self.session() as db:
            found = db.execute(via_contact).scalar()
            if found is None:
                found = db.execute(via_room).scalar()
            return found

    def get_by_pk(self, pk: int) -> Optional[Avatar]:
        """Return the avatar with primary key ``pk``, or ``None``."""
        by_pk = select(Avatar).where(Avatar.id == pk)
        with self.session() as db:
            return db.execute(by_pk).scalar()
165
+
166
+
167
class SentStore(EngineMixin):
    """
    Mapping between XMPP IDs and legacy IDs (``XmppToLegacyIds`` rows),
    kept separately for direct messages, group chat messages and threads.
    """

    # -- shared building blocks ------------------------------------------

    def _store(
        self, user_pk: int, legacy_id: str, xmpp_id: str, kind: XmppToLegacyEnum
    ) -> None:
        """Insert one (legacy_id, xmpp_id) pair of the given ``kind`` and commit."""
        with self.session() as db:
            db.add(
                XmppToLegacyIds(
                    user_account_id=user_pk,
                    legacy_id=legacy_id,
                    xmpp_id=xmpp_id,
                    type=kind,
                )
            )
            db.commit()

    def _xmpp_for(
        self, user_pk: int, legacy_id: str, kind: XmppToLegacyEnum
    ) -> Optional[str]:
        """Return the XMPP ID recorded for ``legacy_id`` and ``kind``, or ``None``."""
        query = (
            select(XmppToLegacyIds.xmpp_id)
            .where(XmppToLegacyIds.user_account_id == user_pk)
            .where(XmppToLegacyIds.legacy_id == legacy_id)
            .where(XmppToLegacyIds.type == kind)
        )
        with self.session() as db:
            return db.execute(query).scalar()

    def _legacy_for(
        self, user_pk: int, xmpp_id: str, kind: XmppToLegacyEnum
    ) -> Optional[str]:
        """Return the legacy ID recorded for ``xmpp_id`` and ``kind``, or ``None``."""
        query = (
            select(XmppToLegacyIds.legacy_id)
            .where(XmppToLegacyIds.user_account_id == user_pk)
            .where(XmppToLegacyIds.xmpp_id == xmpp_id)
            .where(XmppToLegacyIds.type == kind)
        )
        with self.session() as db:
            return db.execute(query).scalar()

    # -- direct messages -------------------------------------------------

    def set_message(self, user_pk: int, legacy_id: str, xmpp_id: str) -> None:
        """Record the ID pair of a direct message."""
        self._store(user_pk, legacy_id, xmpp_id, XmppToLegacyEnum.DM)

    def get_xmpp_id(self, user_pk: int, legacy_id: str) -> Optional[str]:
        """XMPP ID of the direct message with this legacy ID, if recorded."""
        return self._xmpp_for(user_pk, legacy_id, XmppToLegacyEnum.DM)

    def get_legacy_id(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        """Legacy ID of the direct message with this XMPP ID, if recorded."""
        return self._legacy_for(user_pk, xmpp_id, XmppToLegacyEnum.DM)

    # -- group chat messages ---------------------------------------------

    def set_group_message(self, user_pk: int, legacy_id: str, xmpp_id: str) -> None:
        """Record the ID pair of a group chat message."""
        self._store(user_pk, legacy_id, xmpp_id, XmppToLegacyEnum.GROUP_CHAT)

    def get_group_xmpp_id(self, user_pk: int, legacy_id: str) -> Optional[str]:
        """XMPP ID of the group chat message with this legacy ID, if recorded."""
        return self._xmpp_for(user_pk, legacy_id, XmppToLegacyEnum.GROUP_CHAT)

    def get_group_legacy_id(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        """Legacy ID of the group chat message with this XMPP ID, if recorded."""
        return self._legacy_for(user_pk, xmpp_id, XmppToLegacyEnum.GROUP_CHAT)

    # -- threads ---------------------------------------------------------

    def set_thread(self, user_pk: int, legacy_id: str, xmpp_id: str) -> None:
        """Record the ID pair of a thread."""
        self._store(user_pk, legacy_id, xmpp_id, XmppToLegacyEnum.THREAD)

    def get_legacy_thread(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        """Legacy ID of the thread with this XMPP ID, if recorded."""
        return self._legacy_for(user_pk, xmpp_id, XmppToLegacyEnum.THREAD)

    def get_xmpp_thread(self, user_pk: int, legacy_id: str) -> Optional[str]:
        """XMPP ID of the thread with this legacy ID, if recorded."""
        return self._xmpp_for(user_pk, legacy_id, XmppToLegacyEnum.THREAD)

    # -- misc ------------------------------------------------------------

    def was_sent_by_user(self, user_pk: int, legacy_id: str) -> bool:
        """
        Tell whether any pair (of any type) is recorded for this user and
        ``legacy_id``.
        """
        query = (
            select(XmppToLegacyIds.legacy_id)
            .where(XmppToLegacyIds.user_account_id == user_pk)
            .where(XmppToLegacyIds.legacy_id == legacy_id)
        )
        with self.session() as db:
            match = db.execute(query).scalar()
        return match is not None
265
+
266
+
267
class ContactStore(UpdatedMixin):
    """Persistence of ``Contact`` rows and of the per-contact sent-message log."""

    model = Contact

    def __init__(self, *a, **k):
        """Besides the ``updated`` reset of the parent, invalidate all cached presences."""
        super().__init__(*a, **k)
        with self.session() as session:
            session.execute(update(Contact).values(cached_presence=False))
            session.commit()

    def add(self, user_pk: int, legacy_id: str, contact_jid: JID) -> int:
        """
        Return the primary key of the contact of user ``user_pk`` identified
        by ``legacy_id`` and ``contact_jid.bare``, inserting it if needed.
        """
        with self.session() as session:
            existing = session.execute(
                select(Contact)
                # Scope the lookup to the owning user, like every other
                # lookup of this store; without it, another user's row with
                # the same legacy ID and JID would be returned.
                .where(Contact.user_account_id == user_pk)
                .where(Contact.legacy_id == legacy_id)
                .where(Contact.jid == contact_jid.bare)
            ).scalar()
            if existing is not None:
                return existing.id
            contact = Contact(
                jid=contact_jid.bare, legacy_id=legacy_id, user_account_id=user_pk
            )
            session.add(contact)
            session.commit()
            return contact.id

    def get_all(self, user_pk: int) -> Iterator[Contact]:
        """Iterate over all contacts of user ``user_pk``."""
        with self.session() as session:
            yield from session.execute(
                select(Contact).where(Contact.user_account_id == user_pk)
            ).scalars()

    def get_by_jid(self, user_pk: int, jid: JID) -> Optional[Contact]:
        """Return the contact of ``user_pk`` whose JID is ``jid.bare``, or ``None``."""
        with self.session() as session:
            return session.execute(
                select(Contact)
                .where(Contact.jid == jid.bare)
                .where(Contact.user_account_id == user_pk)
            ).scalar()

    def get_by_legacy_id(self, user_pk: int, legacy_id: str) -> Optional[Contact]:
        """Return the contact of ``user_pk`` with this ``legacy_id``, or ``None``."""
        with self.session() as session:
            return session.execute(
                select(Contact)
                .where(Contact.legacy_id == legacy_id)
                .where(Contact.user_account_id == user_pk)
            ).scalar()

    def update_nick(self, contact_pk: int, nick: Optional[str]) -> None:
        """Set the ``nick`` column of contact ``contact_pk``."""
        with self.session() as session:
            session.execute(
                update(Contact).where(Contact.id == contact_pk).values(nick=nick)
            )
            session.commit()

    def get_presence(self, contact_pk: int) -> Optional[CachedPresence]:
        """
        Return the cached presence of contact ``contact_pk``, or ``None``
        when the contact does not exist or has no valid cached presence.
        """
        with self.session() as session:
            presence = session.execute(
                select(
                    Contact.last_seen,
                    Contact.ptype,
                    Contact.pstatus,
                    Contact.pshow,
                    Contact.cached_presence,
                ).where(Contact.id == contact_pk)
            ).first()
            # The last selected column is the "cache is valid" flag.
            if presence is None or not presence[-1]:
                return None
            return CachedPresence(*presence[:-1])

    def set_presence(self, contact_pk: int, presence: CachedPresence) -> None:
        """Store ``presence`` for contact ``contact_pk`` and mark the cache valid."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(**presence._asdict(), cached_presence=True)
            )
            session.commit()

    def reset_presence(self, contact_pk: int):
        """Clear the stored presence of contact ``contact_pk`` and mark the cache invalid."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(
                    last_seen=None,
                    ptype=None,
                    pstatus=None,
                    pshow=None,
                    cached_presence=False,
                )
            )
            session.commit()

    def set_avatar(self, contact_pk: int, avatar_pk: Optional[int]):
        """Set (or unset, with ``None``) the avatar of contact ``contact_pk``."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(avatar_id=avatar_pk)
            )
            session.commit()

    def get_avatar_legacy_id(self, contact_pk: int) -> Optional[str]:
        """Return the legacy ID of the contact's avatar, or ``None`` if there is none."""
        with self.session() as session:
            contact = session.execute(
                select(Contact).where(Contact.id == contact_pk)
            ).scalar()
            if contact is None or contact.avatar is None:
                return None
            return contact.avatar.legacy_id

    def update(self, contact: "LegacyContact"):
        """Write the state of ``contact`` to its row and flag the row as updated."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact.contact_pk)
                .values(
                    nick=contact.name,
                    is_friend=contact.is_friend,
                    added_to_roster=contact.added_to_roster,
                    updated=True,
                    extra_attributes=contact.serialize_extra_attributes(),
                )
            )
            session.commit()

    def add_to_sent(self, contact_pk: int, msg_id: str) -> None:
        """Append ``msg_id`` to the sent-message log of contact ``contact_pk``."""
        with self.session() as session:
            new = ContactSent(contact_id=contact_pk, msg_id=msg_id)
            session.add(new)
            session.commit()

    def pop_sent_up_to(self, contact_pk: int, msg_id: str) -> list[str]:
        """
        Remove and return, oldest first, the logged message IDs of contact
        ``contact_pk`` up to and including ``msg_id``.

        If ``msg_id`` is not in the log, the whole log is removed and
        returned.
        """
        result = []
        last_row_id = None
        with self.session() as session:
            for row in session.execute(
                select(ContactSent)
                .where(ContactSent.contact_id == contact_pk)
                .order_by(ContactSent.id)
            ).scalars():
                last_row_id = row.id
                result.append(row.msg_id)
                if row.msg_id == msg_id:
                    break
            if last_row_id is not None:
                # Rows were visited in ascending id order, so the popped rows
                # are exactly this contact's rows with id <= last_row_id.
                session.execute(
                    delete(ContactSent)
                    .where(ContactSent.contact_id == contact_pk)
                    .where(ContactSent.id <= last_row_id)
                )
                # Without a commit the deletion is discarded when the
                # session is closed.
                session.commit()
        return result

    def set_friend(self, contact_pk: int, is_friend: bool) -> None:
        """Set the ``is_friend`` column of contact ``contact_pk``."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(is_friend=is_friend)
            )
            session.commit()
424
+
425
+
426
class MAMStore(EngineMixin):
    """Persistence of archived room messages (``ArchivedMessage`` rows)."""

    def nuke_older_than(self, days: int) -> None:
        """Delete every archived message whose timestamp is more than ``days`` days old."""
        # get_messages() interprets stored timestamps as UTC, so the cutoff
        # is computed in UTC; it is kept naive, as the original bound value was.
        cutoff = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(days=days)
        with self.session() as session:
            session.execute(
                # "older than" means strictly before the cutoff; comparing
                # with ">" would wipe the recent messages instead.
                delete(ArchivedMessage).where(ArchivedMessage.timestamp < cutoff)
            )
            session.commit()

    def add_message(self, room_pk: int, message: HistoryMessage) -> None:
        """
        Archive ``message`` in room ``room_pk``.

        If a message with the same stanza ID is already archived in this
        room, its timestamp, stanza and author are overwritten instead.
        """
        with self.session() as session:
            existing = session.execute(
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.stanza_id == message.id)
            ).scalar()
            if existing is not None:
                log.debug("Updating message %s in room %s", message.id, room_pk)
                existing.timestamp = message.when
                existing.stanza = str(message.stanza)
                existing.author_jid = message.stanza.get_from()
                session.add(existing)
                session.commit()
                return
            mam_msg = ArchivedMessage(
                stanza_id=message.id,
                timestamp=message.when,
                stanza=str(message.stanza),
                author_jid=message.stanza.get_from(),
                room_id=room_pk,
            )
            session.add(mam_msg)
            session.commit()

    def _timestamp_of(self, session: Session, room_pk: int, stanza_id: str) -> datetime:
        """
        Return the timestamp of message ``stanza_id`` archived in room
        ``room_pk``; raise ``XMPPError`` (item-not-found) if it is unknown.
        """
        stamp = session.execute(
            select(ArchivedMessage.timestamp)
            # Restrict to the queried room: the same stanza ID may be
            # archived in another room with a different timestamp.
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.stanza_id == stanza_id)
        ).scalar()
        if stamp is None:
            raise XMPPError(
                "item-not-found",
                f"Message {stanza_id} not found",
            )
        return stamp

    def get_messages(
        self,
        room_pk: int,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        before_id: Optional[str] = None,
        after_id: Optional[str] = None,
        ids: Collection[str] = (),
        last_page_n: Optional[int] = None,
        sender: Optional[str] = None,
        flip=False,
    ) -> Iterator[HistoryMessage]:
        """
        Iterate over the archived messages of room ``room_pk`` matching all
        the given constraints.

        :param start_date: keep messages at or after this timestamp
        :param end_date: keep messages at or before this timestamp
        :param before_id: keep messages strictly older than this message
        :param after_id: keep messages strictly newer than this message
        :param ids: keep only messages with one of these stanza IDs
        :param last_page_n: keep only the ``last_page_n`` most recent matches
        :param sender: keep only messages whose author JID equals this value
        :param flip: yield newest first instead of oldest first
        :raises XMPPError: item-not-found if ``before_id`` or ``after_id`` is
            not archived in this room, or if the number of matches differs
            from ``len(ids)`` when ``ids`` is given. As this is a generator,
            the error is raised on first iteration.
        """
        with self.session() as session:
            q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)
            if start_date is not None:
                q = q.where(ArchivedMessage.timestamp >= start_date)
            if end_date is not None:
                q = q.where(ArchivedMessage.timestamp <= end_date)
            if before_id is not None:
                stamp = self._timestamp_of(session, room_pk, before_id)
                q = q.where(ArchivedMessage.timestamp < stamp)
            if after_id is not None:
                stamp = self._timestamp_of(session, room_pk, after_id)
                q = q.where(ArchivedMessage.timestamp > stamp)
            if ids:
                q = q.filter(ArchivedMessage.stanza_id.in_(ids))
            if sender is not None:
                q = q.where(ArchivedMessage.author_jid == sender)
            if flip:
                q = q.order_by(ArchivedMessage.timestamp.desc())
            else:
                q = q.order_by(ArchivedMessage.timestamp.asc())
            msgs = list(session.execute(q).scalars())
            if ids and len(msgs) != len(ids):
                raise XMPPError(
                    "item-not-found",
                    "One of the requested messages IDs could not be found "
                    "with the given constraints.",
                )
            if last_page_n is not None:
                if flip:
                    # Newest first: the last page is the head of the list.
                    msgs = msgs[:last_page_n]
                else:
                    # ``msgs[-0:]`` would be the whole list, so an explicit
                    # page size of 0 needs its own branch.
                    msgs = msgs[-last_page_n:] if last_page_n else []
            for h in msgs:
                yield HistoryMessage(
                    stanza=str(h.stanza), when=h.timestamp.replace(tzinfo=timezone.utc)
                )

    def get_first_and_last(self, room_pk: int) -> list[MamMetadata]:
        """
        Return metadata (stanza ID, timestamp) of the oldest and newest
        archived messages of room ``room_pk``, in that order; the list is
        empty when the room has no archived message.
        """
        r = []
        with self.session() as session:
            first = session.execute(
                select(ArchivedMessage.stanza_id, ArchivedMessage.timestamp)
                .where(ArchivedMessage.room_id == room_pk)
                .order_by(ArchivedMessage.timestamp.asc())
            ).first()
            if first is not None:
                r.append(MamMetadata(*first))
            last = session.execute(
                select(ArchivedMessage.stanza_id, ArchivedMessage.timestamp)
                .where(ArchivedMessage.room_id == room_pk)
                .order_by(ArchivedMessage.timestamp.desc())
            ).first()
            if last is not None:
                r.append(MamMetadata(*last))
        return r
547
+
548
+
549
class MultiStore(EngineMixin):
    """
    Mapping of one legacy message ID to several XMPP IDs
    (``LegacyIdsMulti`` / ``XmppIdsMulti`` rows).
    """

    def get_xmpp_ids(self, user_pk: int, xmpp_id: str) -> list[str]:
        """
        Return all the XMPP IDs attached to the same legacy message as
        ``xmpp_id``, or an empty list when ``xmpp_id`` is not recorded.
        """
        with self.session() as session:
            multi = session.execute(
                select(XmppIdsMulti)
                .where(XmppIdsMulti.xmpp_id == xmpp_id)
                .where(XmppIdsMulti.user_account_id == user_pk)
            ).scalar()
            if multi is None:
                return []
            return [m.xmpp_id for m in multi.legacy_ids_multi.xmpp_ids]

    def set_xmpp_ids(
        self, user_pk: int, legacy_msg_id: str, xmpp_ids: list[str], fail=False
    ) -> None:
        """
        Record that ``legacy_msg_id`` corresponds to ``xmpp_ids`` (empty
        IDs are skipped).

        If a mapping already exists for ``legacy_msg_id``, it is removed,
        along with any recorded entry for one of ``xmpp_ids``, and the
        insertion is attempted once more.

        :param fail: raise instead of resetting an existing mapping; used
            for the single retry described above.
        :raises RuntimeError: if a mapping exists and ``fail`` is true.
        """
        with self.session() as session:
            existing = session.execute(
                select(LegacyIdsMulti)
                .where(LegacyIdsMulti.user_account_id == user_pk)
                .where(LegacyIdsMulti.legacy_id == legacy_msg_id)
            ).scalar()
            if existing is not None:
                if fail:
                    # A bare ``raise`` here has no active exception to
                    # re-raise; raise the same exception type explicitly,
                    # with a message that says what went wrong.
                    raise RuntimeError(
                        f"Multi for {legacy_msg_id} still exists after reset"
                    )
                log.warning("Resetting multi for %s", legacy_msg_id)
                session.execute(
                    delete(LegacyIdsMulti)
                    .where(LegacyIdsMulti.user_account_id == user_pk)
                    .where(LegacyIdsMulti.legacy_id == legacy_msg_id)
                )
                for i in xmpp_ids:
                    session.execute(
                        delete(XmppIdsMulti)
                        .where(XmppIdsMulti.user_account_id == user_pk)
                        .where(XmppIdsMulti.xmpp_id == i)
                    )
                session.commit()
                # Retry once; fail=True prevents an endless reset loop.
                self.set_xmpp_ids(user_pk, legacy_msg_id, xmpp_ids, True)
                return

            row = LegacyIdsMulti(
                user_account_id=user_pk,
                legacy_id=legacy_msg_id,
                xmpp_ids=[
                    XmppIdsMulti(user_account_id=user_pk, xmpp_id=i)
                    for i in xmpp_ids
                    if i
                ],
            )
            session.add(row)
            session.commit()

    def get_legacy_id(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        """Return the legacy message ID that ``xmpp_id`` is attached to, or ``None``."""
        with self.session() as session:
            multi = session.execute(
                select(XmppIdsMulti)
                .where(XmppIdsMulti.xmpp_id == xmpp_id)
                .where(XmppIdsMulti.user_account_id == user_pk)
            ).scalar()
            if multi is None:
                return None
            return multi.legacy_ids_multi.legacy_id
611
+
612
+
613
class AttachmentStore(EngineMixin):
    """Persistence of ``Attachment`` rows (URL, SIMS and SFS of legacy files)."""

    def get_url(self, legacy_file_id: str) -> Optional[str]:
        """Return the URL stored for ``legacy_file_id``, or ``None``."""
        query = select(Attachment.url).where(
            Attachment.legacy_file_id == legacy_file_id
        )
        with self.session() as db:
            return db.execute(query).scalar()

    def set_url(self, user_pk: int, legacy_file_id: str, url: str) -> None:
        """
        Store ``url`` for the attachment ``legacy_file_id`` of user
        ``user_pk``, creating the row when it does not exist yet.
        """
        lookup = (
            select(Attachment)
            .where(Attachment.legacy_file_id == legacy_file_id)
            .where(Attachment.user_account_id == user_pk)
        )
        with self.session() as db:
            attachment = db.execute(lookup).scalar()
            if attachment is not None:
                attachment.url = url
            else:
                attachment = Attachment(
                    legacy_file_id=legacy_file_id, url=url, user_account_id=user_pk
                )
                db.add(attachment)
            db.commit()

    def get_sims(self, url: str) -> Optional[str]:
        """Return the ``sims`` value stored for the attachment at ``url``, or ``None``."""
        query = select(Attachment.sims).where(Attachment.url == url)
        with self.session() as db:
            return db.execute(query).scalar()

    def set_sims(self, url: str, sims: str) -> None:
        """Set the ``sims`` value of the attachment(s) stored with ``url``."""
        statement = update(Attachment).where(Attachment.url == url).values(sims=sims)
        with self.session() as db:
            db.execute(statement)
            db.commit()

    def get_sfs(self, url: str) -> Optional[str]:
        """Return the ``sfs`` value stored for the attachment at ``url``, or ``None``."""
        query = select(Attachment.sfs).where(Attachment.url == url)
        with self.session() as db:
            return db.execute(query).scalar()

    def set_sfs(self, url: str, sfs: str) -> None:
        """Set the ``sfs`` value of the attachment(s) stored with ``url``."""
        statement = update(Attachment).where(Attachment.url == url).values(sfs=sfs)
        with self.session() as db:
            db.execute(statement)
            db.commit()

    def remove(self, legacy_file_id: str) -> None:
        """Delete the attachment row(s) stored for ``legacy_file_id``."""
        statement = delete(Attachment).where(
            Attachment.legacy_file_id == legacy_file_id
        )
        with self.session() as db:
            db.execute(statement)
            db.commit()
670
+
671
+
672
class RoomStore(UpdatedMixin):
    """Persistence of ``Room`` rows."""

    model = Room

    def __init__(self, *args, **kwargs):
        """Besides the ``updated`` reset of the parent, clear subject setters and user resources."""
        super().__init__(*args, **kwargs)
        with self.session() as db:
            db.execute(
                update(Room).values(subject_setter_id=None, user_resources=None)
            )
            db.commit()

    def add(self, user_pk: int, legacy_id: str, jid: JID) -> int:
        """
        Return the primary key of the room of user ``user_pk`` with this
        ``legacy_id``, inserting it with ``jid`` if needed.

        :raises TypeError: if ``jid`` has a resource part.
        """
        if jid.resource:
            raise TypeError
        lookup = (
            select(Room.id)
            .where(Room.user_account_id == user_pk)
            .where(Room.legacy_id == legacy_id)
        )
        with self.session() as db:
            known_pk = db.execute(lookup).scalar()
            if known_pk is not None:
                return known_pk
            new_room = Room(jid=jid, user_account_id=user_pk, legacy_id=legacy_id)
            db.add(new_room)
            db.commit()
            return new_room.id

    def set_avatar(self, room_pk: int, avatar_pk: int) -> None:
        """Set the avatar of room ``room_pk``."""
        with self.session() as db:
            db.execute(
                update(Room).where(Room.id == room_pk).values(avatar_id=avatar_pk)
            )
            db.commit()

    def get_avatar_legacy_id(self, room_pk: int) -> Optional[str]:
        """Return the legacy ID of the room's avatar, or ``None`` if there is none."""
        with self.session() as db:
            found = db.execute(select(Room).where(Room.id == room_pk)).scalar()
            if found is not None and found.avatar is not None:
                return found.avatar.legacy_id
            return None

    def get_by_jid(self, user_pk: int, jid: JID) -> Optional[Room]:
        """
        Return the room of user ``user_pk`` with this ``jid``, or ``None``.

        :raises TypeError: if ``jid`` has a resource part.
        """
        if jid.resource:
            raise TypeError
        lookup = (
            select(Room)
            .where(Room.user_account_id == user_pk)
            .where(Room.jid == jid)
        )
        with self.session() as db:
            return db.execute(lookup).scalar()

    def get_by_legacy_id(self, user_pk: int, legacy_id: str) -> Optional[Room]:
        """Return the room of user ``user_pk`` with this ``legacy_id``, or ``None``."""
        lookup = (
            select(Room)
            .where(Room.user_account_id == user_pk)
            .where(Room.legacy_id == legacy_id)
        )
        with self.session() as db:
            return db.execute(lookup).scalar()

    def update(self, room: "LegacyMUC"):
        """Write the state of ``room`` to its row and flag the row as updated."""
        from slidge.contact import LegacyContact

        with self.session() as db:
            setter = room.subject_setter
            # Only a non-system participant has a primary key that can be
            # stored; every other kind of setter is stored as NULL.
            if (
                setter is None
                or isinstance(setter, (str, LegacyContact))
                or setter.is_system
            ):
                subject_setter_id = None
            else:
                subject_setter_id = setter.pk

            if room._user_resources:
                user_resources = json.dumps(list(room._user_resources))
            else:
                user_resources = None

            db.execute(
                update(Room)
                .where(Room.id == room.pk)
                .values(
                    updated=True,
                    extra_attributes=room.serialize_extra_attributes(),
                    name=room.name,
                    description=room.description,
                    user_resources=user_resources,
                    muc_type=room.type,
                    subject=room.subject,
                    subject_date=room.subject_date,
                    subject_setter_id=subject_setter_id,
                    participants_filled=room._participants_filled,
                    n_participants=room._n_participants,
                )
            )
            db.commit()

    def update_subject_date(
        self, room_pk: int, subject_date: Optional[datetime]
    ) -> None:
        """Set the ``subject_date`` column of room ``room_pk``."""
        with self.session() as db:
            db.execute(
                update(Room).where(Room.id == room_pk).values(subject_date=subject_date)
            )
            db.commit()

    def update_subject(self, room_pk: int, subject: Optional[str]) -> None:
        """Set the ``subject`` column of room ``room_pk``."""
        with self.session() as db:
            db.execute(update(Room).where(Room.id == room_pk).values(subject=subject))
            db.commit()

    def update_description(self, room_pk: int, desc: Optional[str]) -> None:
        """Set the ``description`` column of room ``room_pk``."""
        with self.session() as db:
            db.execute(
                update(Room).where(Room.id == room_pk).values(description=desc)
            )
            db.commit()

    def update_n_participants(self, room_pk: int, n: Optional[int]) -> None:
        """Set the ``n_participants`` column of room ``room_pk``."""
        with self.session() as db:
            db.execute(update(Room).where(Room.id == room_pk).values(n_participants=n))
            db.commit()

    def delete(self, room_pk: int) -> None:
        """Delete room ``room_pk`` and then its participants."""
        with self.session() as db:
            db.execute(delete(Room).where(Room.id == room_pk))
            db.execute(delete(Participant).where(Participant.room_id == room_pk))
            db.commit()

    def set_resource(self, room_pk: int, resources: set[str]) -> None:
        """
        Store ``resources`` as a JSON list in ``user_resources`` of room
        ``room_pk``; an empty set is stored as NULL.
        """
        if resources:
            encoded = json.dumps(list(resources))
        else:
            encoded = None
        with self.session() as db:
            db.execute(
                update(Room).where(Room.id == room_pk).values(user_resources=encoded)
            )
            db.commit()

    def nickname_is_available(self, room_pk: int, nickname: str) -> bool:
        """Tell whether no participant of room ``room_pk`` uses ``nickname``."""
        lookup = (
            select(Participant)
            .where(Participant.room_id == room_pk)
            .where(Participant.nickname == nickname)
        )
        with self.session() as db:
            holder = db.execute(lookup).scalar()
            return holder is None

    def get_all(self, user_pk: int) -> Iterator[Room]:
        """Iterate over all rooms of user ``user_pk``."""
        with self.session() as db:
            for found in db.execute(
                select(Room).where(Room.user_account_id == user_pk)
            ).scalars():
                yield found
834
+
835
+
836
class ParticipantStore(EngineMixin):
    """Persistence of ``Participant`` rows and of their ``Hat`` rows."""

    def __init__(self, *a, **kw):
        """Delete every stored participant and hat when the store is created."""
        super().__init__(*a, **kw)
        with self.session() as session:
            session.execute(delete(Participant))
            session.execute(delete(Hat))
            session.commit()

    def add(self, room_pk: int, nickname: str) -> int:
        """
        Return the primary key of the participant ``nickname`` of room
        ``room_pk``, inserting it if needed.
        """
        with self.session() as session:
            existing = session.execute(
                select(Participant.id)
                .where(Participant.room_id == room_pk)
                .where(Participant.nickname == nickname)
            ).scalar()
            if existing is not None:
                return existing
            participant = Participant(room_id=room_pk, nickname=nickname)
            session.add(participant)
            session.commit()
            return participant.id

    def get_by_nickname(self, room_pk: int, nickname: str) -> Optional[Participant]:
        """Return the participant of room ``room_pk`` with this ``nickname``, or ``None``."""
        with self.session() as session:
            return session.execute(
                select(Participant)
                .where(Participant.room_id == room_pk)
                .where(Participant.nickname == nickname)
            ).scalar()

    def get_by_resource(self, room_pk: int, resource: str) -> Optional[Participant]:
        """Return the participant of room ``room_pk`` with this ``resource``, or ``None``."""
        with self.session() as session:
            return session.execute(
                select(Participant)
                .where(Participant.room_id == room_pk)
                .where(Participant.resource == resource)
            ).scalar()

    def get_by_contact(self, room_pk: int, contact_pk: int) -> Optional[Participant]:
        """Return the participant of room ``room_pk`` linked to contact ``contact_pk``, or ``None``."""
        with self.session() as session:
            return session.execute(
                select(Participant)
                .where(Participant.room_id == room_pk)
                .where(Participant.contact_id == contact_pk)
            ).scalar()

    def get_all(self, room_pk: int, user_included=True) -> Iterator[Participant]:
        """
        Iterate over the participants of room ``room_pk``.

        :param user_included: when false, skip rows flagged ``is_user``.
        """
        with self.session() as session:
            q = select(Participant).where(Participant.room_id == room_pk)
            if not user_included:
                q = q.where(~Participant.is_user)
            yield from session.execute(q).scalars()

    def get_for_contact(self, contact_pk: int) -> Iterator[Participant]:
        """Iterate over the participants (in any room) linked to contact ``contact_pk``."""
        with self.session() as session:
            yield from session.execute(
                select(Participant).where(Participant.contact_id == contact_pk)
            ).scalars()

    def update(self, participant: "LegacyParticipant") -> None:
        """Write the state of ``participant`` to its row."""
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant.pk)
                .values(
                    resource=participant.jid.resource,
                    affiliation=participant.affiliation,
                    role=participant.role,
                    presence_sent=participant._presence_sent,  # type:ignore
                    # Hats are not written here; set_hats() handles them.
                    is_user=participant.is_user,
                    contact_id=(
                        None
                        if participant.contact is None
                        else participant.contact.contact_pk
                    ),
                )
            )
            session.commit()

    def add_hat(self, uri: str, title: str) -> Hat:
        """Return the hat with this ``uri`` and ``title``, inserting it if needed."""
        with self.session() as session:
            existing = session.execute(
                select(Hat).where(Hat.uri == uri).where(Hat.title == title)
            ).scalar()
            if existing is not None:
                return existing
            hat = Hat(uri=uri, title=title)
            session.add(hat)
            session.commit()
            return hat

    def set_presence_sent(self, participant_pk: int) -> None:
        """Flag participant ``participant_pk`` as having had its presence sent."""
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant_pk)
                .values(presence_sent=True)
            )
            session.commit()

    def set_affiliation(self, participant_pk: int, affiliation: MucAffiliation) -> None:
        """Set the ``affiliation`` column of participant ``participant_pk``."""
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant_pk)
                .values(affiliation=affiliation)
            )
            session.commit()

    def set_role(self, participant_pk: int, role: MucRole) -> None:
        """Set the ``role`` column of participant ``participant_pk``."""
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant_pk)
                .values(role=role)
            )
            session.commit()

    def set_hats(self, participant_pk: int, hats: list[HatTuple]) -> None:
        """
        Replace the hats of participant ``participant_pk`` with ``hats``
        (duplicates are attached once).

        :raises ValueError: if the participant does not exist.
        """
        with self.session() as session:
            part = session.execute(
                select(Participant).where(Participant.id == participant_pk)
            ).scalar()
            if part is None:
                raise ValueError
            part.hats.clear()
            for h in hats:
                # Nested call: add_hat() reuses this session.
                hat = self.add_hat(*h)
                if hat in part.hats:
                    continue
                part.hats.append(hat)
            session.commit()

    def delete(self, participant_pk: int) -> None:
        """Delete participant ``participant_pk``."""
        with self.session() as session:
            session.execute(delete(Participant).where(Participant.id == participant_pk))
            # Without a commit the deletion is discarded when the session
            # is closed.
            session.commit()
973
+
974
+
975
# Module-level logger used by the stores defined above.
log = logging.getLogger(__name__)
# Session currently opened through EngineMixin.session(); nested calls reuse
# it, and it is reset to None when the outermost ``with`` block exits.
_session: Optional[Session] = None