slidge 0.1.3__py3-none-any.whl → 0.2.0a0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (63) hide show
  1. slidge/__init__.py +3 -5
  2. slidge/__main__.py +2 -196
  3. slidge/__version__.py +5 -0
  4. slidge/command/adhoc.py +8 -1
  5. slidge/command/admin.py +5 -6
  6. slidge/command/base.py +1 -2
  7. slidge/command/register.py +32 -16
  8. slidge/command/user.py +85 -5
  9. slidge/contact/contact.py +93 -31
  10. slidge/contact/roster.py +54 -39
  11. slidge/core/config.py +13 -7
  12. slidge/core/gateway/base.py +139 -34
  13. slidge/core/gateway/disco.py +2 -4
  14. slidge/core/gateway/mam.py +1 -4
  15. slidge/core/gateway/ping.py +2 -3
  16. slidge/core/gateway/presence.py +1 -1
  17. slidge/core/gateway/registration.py +32 -21
  18. slidge/core/gateway/search.py +3 -5
  19. slidge/core/gateway/session_dispatcher.py +100 -51
  20. slidge/core/gateway/vcard_temp.py +6 -4
  21. slidge/core/mixins/__init__.py +11 -1
  22. slidge/core/mixins/attachment.py +15 -10
  23. slidge/core/mixins/avatar.py +66 -18
  24. slidge/core/mixins/base.py +8 -2
  25. slidge/core/mixins/message.py +11 -7
  26. slidge/core/mixins/message_maker.py +17 -9
  27. slidge/core/mixins/presence.py +14 -4
  28. slidge/core/pubsub.py +54 -212
  29. slidge/core/session.py +65 -33
  30. slidge/db/__init__.py +4 -0
  31. slidge/db/alembic/env.py +64 -0
  32. slidge/db/alembic/script.py.mako +26 -0
  33. slidge/db/alembic/versions/09f27f098baa_add_missing_attributes_in_room.py +36 -0
  34. slidge/db/alembic/versions/29f5280c61aa_store_subject_setter_in_room.py +37 -0
  35. slidge/db/alembic/versions/8d2ced764698_rely_on_db_to_store_contacts_rooms_and_.py +133 -0
  36. slidge/db/alembic/versions/aa9d82a7f6ef_db_creation.py +76 -0
  37. slidge/db/alembic/versions/b33993e87db3_move_everything_to_persistent_db.py +214 -0
  38. slidge/db/alembic/versions/e91195719c2c_store_users_avatars_persistently.py +26 -0
  39. slidge/db/avatar.py +224 -0
  40. slidge/db/meta.py +65 -0
  41. slidge/db/models.py +365 -0
  42. slidge/db/store.py +976 -0
  43. slidge/group/archive.py +13 -14
  44. slidge/group/bookmarks.py +59 -56
  45. slidge/group/participant.py +77 -25
  46. slidge/group/room.py +242 -142
  47. slidge/main.py +201 -0
  48. slidge/migration.py +30 -0
  49. slidge/slixfix/__init__.py +35 -2
  50. slidge/slixfix/roster.py +11 -4
  51. slidge/slixfix/xep_0292/vcard4.py +1 -0
  52. slidge/util/db.py +1 -47
  53. slidge/util/test.py +21 -4
  54. slidge/util/types.py +24 -4
  55. {slidge-0.1.3.dist-info → slidge-0.2.0a0.dist-info}/METADATA +3 -1
  56. slidge-0.2.0a0.dist-info/RECORD +108 -0
  57. slidge/core/cache.py +0 -183
  58. slidge/util/schema.sql +0 -126
  59. slidge/util/sql.py +0 -508
  60. slidge-0.1.3.dist-info/RECORD +0 -96
  61. {slidge-0.1.3.dist-info → slidge-0.2.0a0.dist-info}/LICENSE +0 -0
  62. {slidge-0.1.3.dist-info → slidge-0.2.0a0.dist-info}/WHEEL +0 -0
  63. {slidge-0.1.3.dist-info → slidge-0.2.0a0.dist-info}/entry_points.txt +0 -0
slidge/db/store.py ADDED
@@ -0,0 +1,976 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import logging
5
+ import warnings
6
+ from contextlib import contextmanager
7
+ from datetime import datetime, timedelta, timezone
8
+ from typing import TYPE_CHECKING, Collection, Iterator, Optional, Type
9
+
10
+ from slixmpp import JID, Iq, Message, Presence
11
+ from slixmpp.exceptions import XMPPError
12
+ from sqlalchemy import Engine, delete, select, update
13
+ from sqlalchemy.orm import Session, attributes
14
+
15
+ from ..util.archive_msg import HistoryMessage
16
+ from ..util.types import URL, CachedPresence
17
+ from ..util.types import Hat as HatTuple
18
+ from ..util.types import MamMetadata, MucAffiliation, MucRole
19
+ from .meta import Base
20
+ from .models import (
21
+ ArchivedMessage,
22
+ Attachment,
23
+ Avatar,
24
+ Contact,
25
+ ContactSent,
26
+ GatewayUser,
27
+ Hat,
28
+ LegacyIdsMulti,
29
+ Participant,
30
+ Room,
31
+ XmppIdsMulti,
32
+ XmppToLegacyEnum,
33
+ XmppToLegacyIds,
34
+ )
35
+
36
+ if TYPE_CHECKING:
37
+ from ..contact.contact import LegacyContact
38
+ from ..group.participant import LegacyParticipant
39
+ from ..group.room import LegacyMUC
40
+
41
+
42
class EngineMixin:
    """Base class for the stores: keeps the SQLAlchemy engine and opens sessions."""

    def __init__(self, engine: Engine):
        self._engine = engine

    @contextmanager
    def session(self, **session_kwargs) -> Iterator[Session]:
        """
        Yield an ORM session, reusing the module-level ``_session`` when one is open.

        Nested calls (a store method calling another store method while inside
        a ``session()`` block) receive the very same ``Session`` object and thus
        share its transaction. ``session_kwargs`` are only applied when a new
        session is actually created; they are ignored when an outer session is
        reused.

        NOTE(review): store methods that ``yield from`` inside this block keep
        ``_session`` bound until their generator is exhausted or closed, so any
        store call made meanwhile joins that session.
        """
        global _session
        if _session is None:
            with Session(self._engine, **session_kwargs) as fresh:
                _session = fresh
                try:
                    yield fresh
                finally:
                    # Free the shared slot so the next outermost call opens a new session.
                    _session = None
        else:
            yield _session
58
+
59
+
60
class UpdatedMixin(EngineMixin):
    """
    Store base whose ``model`` table has an ``updated`` flag.

    On instantiation every row's ``updated`` column is reset to False; the
    subclasses' ``update()`` methods set it back to True for rows they write.
    """

    model: Type[Base] = NotImplemented

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)
        clear_updated = update(self.model).values(updated=False)  # type:ignore
        with self.session() as session:
            session.execute(clear_updated)
            session.commit()
68
+
69
+
70
class SlidgeStore(EngineMixin):
    """Facade exposing one store object per group of tables, all on ``engine``."""

    def __init__(self, engine: Engine):
        super().__init__(engine)
        # Several stores reset rows in their constructors (ContactStore,
        # RoomStore, ParticipantStore). RoomStore clears subject_setter_id
        # before ParticipantStore wipes participants, so keep this order.
        self.users = UserStore(engine)
        self.avatars = AvatarStore(engine)
        self.contacts = ContactStore(engine)
        self.mam = MAMStore(engine)
        self.multi = MultiStore(engine)
        self.attachments = AttachmentStore(engine)
        self.rooms = RoomStore(engine)
        self.sent = SentStore(engine)
        self.participants = ParticipantStore(engine)
82
+
83
+
84
class UserStore(EngineMixin):
    """Persistence of gateway users, one row per bare JID."""

    def new(self, jid: JID, legacy_module_data: dict) -> GatewayUser:
        """
        Return the user registered with ``jid``, creating it if needed.

        The resource part of ``jid`` is dropped. ``legacy_module_data`` is only
        used when a new row is created; an existing user is returned as is.
        """
        if jid.resource:
            jid = JID(jid.bare)
        with self.session(expire_on_commit=False) as session:
            user = session.execute(
                select(GatewayUser).where(GatewayUser.jid == jid)
            ).scalar()
            if user is not None:
                return user
            user = GatewayUser(jid=jid, legacy_module_data=legacy_module_data)
            session.add(user)
            session.commit()
            return user

    def update(self, user: GatewayUser):
        """Persist ``user``, including in-place changes to its JSON columns."""
        # In-place mutations of JSON dicts are not tracked by SQLAlchemy, so
        # both columns are explicitly flagged as modified.
        # https://github.com/sqlalchemy/sqlalchemy/discussions/6473
        attributes.flag_modified(user, "legacy_module_data")
        attributes.flag_modified(user, "preferences")
        with self.session() as session:
            session.add(user)
            session.commit()

    def get_all(self) -> Iterator[GatewayUser]:
        """Yield every user; the session stays open while the caller iterates."""
        with self.session() as session:
            yield from session.execute(select(GatewayUser)).scalars()

    def get(self, jid: JID) -> Optional[GatewayUser]:
        """Return the user whose bare JID matches ``jid``, or None."""
        with self.session() as session:
            return session.execute(
                select(GatewayUser).where(GatewayUser.jid == jid.bare)
            ).scalar()

    def get_by_stanza(self, stanza: Iq | Message | Presence) -> Optional[GatewayUser]:
        """Return the user who sent ``stanza``, or None."""
        return self.get(stanza.get_from())

    def delete(self, jid: JID) -> None:
        """Delete the user registered with ``jid``; do nothing if there is none."""
        with self.session() as session:
            # self.get() runs inside this block, so it reuses this session and
            # the returned instance is attached to it.
            user = self.get(jid)
            if user is None:
                # session.delete(None) would raise an SQLAlchemy error.
                return
            session.delete(user)
            session.commit()
124
+
125
+
126
class AvatarStore(EngineMixin):
    """Lookups over the ``Avatar`` table and its links to contacts and rooms."""

    def get_by_url(self, url: URL | str) -> Optional[Avatar]:
        """Return the avatar fetched from ``url``, or None."""
        query = select(Avatar).where(Avatar.url == url)
        with self.session() as session:
            return session.execute(query).scalar()

    def get_by_legacy_id(self, legacy_id: str) -> Optional[Avatar]:
        """Return the avatar with the given legacy ID, or None."""
        query = select(Avatar).where(Avatar.legacy_id == legacy_id)
        with self.session() as session:
            return session.execute(query).scalar()

    def store_jid(self, sha: str, jid: JID) -> None:
        """
        Link the avatar whose hash is ``sha`` to the contacts and rooms whose
        JID equals the bare form of ``jid``.

        Emits a warning and changes nothing if no avatar has that hash.
        """
        bare = jid.bare
        with self.session() as session:
            avatar_pk = session.execute(
                select(Avatar.id).where(Avatar.hash == sha)
            ).scalar()
            if avatar_pk is None:
                warnings.warn("avatar not found")
                return
            for table in (Contact, Room):
                session.execute(
                    update(table).where(table.jid == bare).values(avatar_id=avatar_pk)
                )
            session.commit()

    def get_by_jid(self, jid: JID) -> Optional[Avatar]:
        """
        Return the avatar of the contact whose JID equals ``jid``, falling back
        to the room with that JID; None if neither has one.
        """
        of_contact = Avatar.contacts.any(Contact.jid == jid)
        of_room = Avatar.rooms.any(Room.jid == jid)
        with self.session() as session:
            found = session.execute(select(Avatar).where(of_contact)).scalar()
            if found is None:
                found = session.execute(select(Avatar).where(of_room)).scalar()
            return found

    def get_by_pk(self, pk: int) -> Optional[Avatar]:
        """Return the avatar with primary key ``pk``, or None."""
        query = select(Avatar).where(Avatar.id == pk)
        with self.session() as session:
            return session.execute(query).scalar()
165
+
166
+
167
class SentStore(EngineMixin):
    """
    Per-user mapping between legacy message IDs and XMPP message IDs.

    Direct messages, group chat messages and threads share one table
    (``XmppToLegacyIds``) and are distinguished by its ``type`` column.
    """

    def _insert(
        self, user_pk: int, legacy_id: str, xmpp_id: str, kind: XmppToLegacyEnum
    ) -> None:
        """Store one legacy ID <-> XMPP ID pair of the given kind."""
        with self.session() as session:
            session.add(
                XmppToLegacyIds(
                    user_account_id=user_pk,
                    legacy_id=legacy_id,
                    xmpp_id=xmpp_id,
                    type=kind,
                )
            )
            session.commit()

    def _xmpp_from_legacy(
        self, user_pk: int, legacy_id: str, kind: XmppToLegacyEnum
    ) -> Optional[str]:
        """Return the XMPP ID paired with ``legacy_id`` for this kind, or None."""
        query = (
            select(XmppToLegacyIds.xmpp_id)
            .where(XmppToLegacyIds.user_account_id == user_pk)
            .where(XmppToLegacyIds.legacy_id == legacy_id)
            .where(XmppToLegacyIds.type == kind)
        )
        with self.session() as session:
            return session.execute(query).scalar()

    def _legacy_from_xmpp(
        self, user_pk: int, xmpp_id: str, kind: XmppToLegacyEnum
    ) -> Optional[str]:
        """Return the legacy ID paired with ``xmpp_id`` for this kind, or None."""
        query = (
            select(XmppToLegacyIds.legacy_id)
            .where(XmppToLegacyIds.user_account_id == user_pk)
            .where(XmppToLegacyIds.xmpp_id == xmpp_id)
            .where(XmppToLegacyIds.type == kind)
        )
        with self.session() as session:
            return session.execute(query).scalar()

    def set_message(self, user_pk: int, legacy_id: str, xmpp_id: str) -> None:
        self._insert(user_pk, legacy_id, xmpp_id, XmppToLegacyEnum.DM)

    def get_xmpp_id(self, user_pk: int, legacy_id: str) -> Optional[str]:
        return self._xmpp_from_legacy(user_pk, legacy_id, XmppToLegacyEnum.DM)

    def get_legacy_id(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        return self._legacy_from_xmpp(user_pk, xmpp_id, XmppToLegacyEnum.DM)

    def set_group_message(self, user_pk: int, legacy_id: str, xmpp_id: str) -> None:
        self._insert(user_pk, legacy_id, xmpp_id, XmppToLegacyEnum.GROUP_CHAT)

    def get_group_xmpp_id(self, user_pk: int, legacy_id: str) -> Optional[str]:
        return self._xmpp_from_legacy(
            user_pk, legacy_id, XmppToLegacyEnum.GROUP_CHAT
        )

    def get_group_legacy_id(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        return self._legacy_from_xmpp(user_pk, xmpp_id, XmppToLegacyEnum.GROUP_CHAT)

    def set_thread(self, user_pk: int, legacy_id: str, xmpp_id: str) -> None:
        self._insert(user_pk, legacy_id, xmpp_id, XmppToLegacyEnum.THREAD)

    def get_legacy_thread(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        return self._legacy_from_xmpp(user_pk, xmpp_id, XmppToLegacyEnum.THREAD)

    def get_xmpp_thread(self, user_pk: int, legacy_id: str) -> Optional[str]:
        return self._xmpp_from_legacy(user_pk, legacy_id, XmppToLegacyEnum.THREAD)

    def was_sent_by_user(self, user_pk: int, legacy_id: str) -> bool:
        """True if ``legacy_id`` is recorded for this user, whatever its kind."""
        query = (
            select(XmppToLegacyIds.legacy_id)
            .where(XmppToLegacyIds.user_account_id == user_pk)
            .where(XmppToLegacyIds.legacy_id == legacy_id)
        )
        with self.session() as session:
            found = session.execute(query).scalar()
            return found is not None
265
+
266
+
267
class ContactStore(UpdatedMixin):
    """Per-user persistence of legacy contacts (``Contact`` rows)."""

    model = Contact

    def __init__(self, *a, **k):
        super().__init__(*a, **k)
        # Presences cached by a previous run are marked as not cached.
        with self.session() as session:
            session.execute(update(Contact).values(cached_presence=False))
            session.commit()

    def add(self, user_pk: int, legacy_id: str, contact_jid: JID) -> int:
        """
        Return the pk of this user's contact ``legacy_id``, creating it if needed.

        Only the bare part of ``contact_jid`` is stored.
        """
        with self.session() as session:
            existing = session.execute(
                select(Contact)
                .where(Contact.legacy_id == legacy_id)
                .where(Contact.jid == contact_jid.bare)
                # Contacts are per user: without this filter another user's
                # row with the same legacy ID/JID would be returned.
                .where(Contact.user_account_id == user_pk)
            ).scalar()
            if existing is not None:
                return existing.id
            contact = Contact(
                jid=contact_jid.bare, legacy_id=legacy_id, user_account_id=user_pk
            )
            session.add(contact)
            session.commit()
            return contact.id

    def get_all(self, user_pk: int) -> Iterator[Contact]:
        """Yield all contacts of a user; the session stays open while iterating."""
        with self.session() as session:
            yield from session.execute(
                select(Contact).where(Contact.user_account_id == user_pk)
            ).scalars()

    def get_by_jid(self, user_pk: int, jid: JID) -> Optional[Contact]:
        """Return the user's contact whose bare JID matches ``jid``, or None."""
        with self.session() as session:
            return session.execute(
                select(Contact)
                .where(Contact.jid == jid.bare)
                .where(Contact.user_account_id == user_pk)
            ).scalar()

    def get_by_legacy_id(self, user_pk: int, legacy_id: str) -> Optional[Contact]:
        """Return the user's contact with this legacy ID, or None."""
        with self.session() as session:
            return session.execute(
                select(Contact)
                .where(Contact.legacy_id == legacy_id)
                .where(Contact.user_account_id == user_pk)
            ).scalar()

    def update_nick(self, contact_pk: int, nick: Optional[str]) -> None:
        with self.session() as session:
            session.execute(
                update(Contact).where(Contact.id == contact_pk).values(nick=nick)
            )
            session.commit()

    def get_presence(self, contact_pk: int) -> Optional[CachedPresence]:
        """Return the cached presence of a contact, or None if none is cached."""
        with self.session() as session:
            presence = session.execute(
                select(
                    Contact.last_seen,
                    Contact.ptype,
                    Contact.pstatus,
                    Contact.pshow,
                    Contact.cached_presence,
                ).where(Contact.id == contact_pk)
            ).first()
            # The last column is the "cached_presence" flag, not presence data.
            if presence is None or not presence[-1]:
                return None
            return CachedPresence(*presence[:-1])

    def set_presence(self, contact_pk: int, presence: CachedPresence) -> None:
        """Cache ``presence`` for this contact and mark it as cached."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(**presence._asdict(), cached_presence=True)
            )
            session.commit()

    def reset_presence(self, contact_pk: int):
        """Clear the cached presence of this contact."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(
                    last_seen=None,
                    ptype=None,
                    pstatus=None,
                    pshow=None,
                    cached_presence=False,
                )
            )
            session.commit()

    def set_avatar(self, contact_pk: int, avatar_pk: Optional[int]):
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(avatar_id=avatar_pk)
            )
            session.commit()

    def get_avatar_legacy_id(self, contact_pk: int) -> Optional[str]:
        """Return the legacy ID of the contact's avatar, or None."""
        with self.session() as session:
            contact = session.execute(
                select(Contact).where(Contact.id == contact_pk)
            ).scalar()
            if contact is None or contact.avatar is None:
                return None
            return contact.avatar.legacy_id

    def update(self, contact: "LegacyContact"):
        """Write the persistent attributes of ``contact`` and flag the row as updated."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact.contact_pk)
                .values(
                    nick=contact.name,
                    is_friend=contact.is_friend,
                    added_to_roster=contact.added_to_roster,
                    updated=True,
                    extra_attributes=contact.serialize_extra_attributes(),
                )
            )
            session.commit()

    def add_to_sent(self, contact_pk: int, msg_id: str) -> None:
        """Remember that message ``msg_id`` was sent to this contact."""
        with self.session() as session:
            new = ContactSent(contact_id=contact_pk, msg_id=msg_id)
            session.add(new)
            session.commit()

    def pop_sent_up_to(self, contact_pk: int, msg_id: str) -> list[str]:
        """
        Remove and return the sent message IDs of a contact, oldest first,
        up to and including ``msg_id``.

        If ``msg_id`` is not among them, every stored ID is popped.
        """
        result: list[str] = []
        to_del: list[int] = []
        with self.session() as session:
            for row in session.execute(
                select(ContactSent)
                .where(ContactSent.contact_id == contact_pk)
                .order_by(ContactSent.id)
            ).scalars():
                to_del.append(row.id)
                result.append(row.msg_id)
                if row.msg_id == msg_id:
                    break
            if to_del:
                session.execute(
                    delete(ContactSent).where(ContactSent.id.in_(to_del))
                )
                # Without a commit the deletion is rolled back when the
                # session closes and the same IDs would be popped again.
                session.commit()
        return result

    def set_friend(self, contact_pk: int, is_friend: bool) -> None:
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(is_friend=is_friend)
            )
            session.commit()
424
+
425
+
426
class MAMStore(EngineMixin):
    """Archive of group chat messages (MAM), stored per room primary key."""

    def nuke_older_than(self, days: int) -> None:
        """Delete every archived message whose timestamp is older than ``days`` days."""
        # get_messages() treats stored timestamps as UTC, so the cutoff is
        # computed in UTC as well (datetime.now() would be local time).
        cutoff = datetime.now(timezone.utc) - timedelta(days=days)
        with self.session() as session:
            session.execute(
                delete(ArchivedMessage).where(ArchivedMessage.timestamp < cutoff)
            )
            session.commit()

    def add_message(self, room_pk: int, message: HistoryMessage) -> None:
        """
        Archive ``message`` in room ``room_pk``.

        A message already archived in this room with the same ID is overwritten.
        """
        with self.session() as session:
            existing = session.execute(
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.stanza_id == message.id)
            ).scalar()
            if existing is not None:
                log.debug("Updating message %s in room %s", message.id, room_pk)
                existing.timestamp = message.when
                existing.stanza = str(message.stanza)
                existing.author_jid = message.stanza.get_from()
                session.add(existing)
                session.commit()
                return
            mam_msg = ArchivedMessage(
                stanza_id=message.id,
                timestamp=message.when,
                stanza=str(message.stanza),
                author_jid=message.stanza.get_from(),
                room_id=room_pk,
            )
            session.add(mam_msg)
            session.commit()

    @staticmethod
    def _timestamp_of(session: Session, room_pk: int, stanza_id: str) -> datetime:
        """
        Return the timestamp of message ``stanza_id`` archived in room ``room_pk``.

        :raises XMPPError: item-not-found if the message is not in this room.
        """
        stamp = session.execute(
            select(ArchivedMessage.timestamp)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.stanza_id == stanza_id)
        ).scalar()
        if stamp is None:
            raise XMPPError(
                "item-not-found",
                f"Message {stanza_id} not found",
            )
        return stamp

    def get_messages(
        self,
        room_pk: int,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        before_id: Optional[str] = None,
        after_id: Optional[str] = None,
        ids: Collection[str] = (),
        last_page_n: Optional[int] = None,
        sender: Optional[str] = None,
        flip=False,
    ) -> Iterator[HistoryMessage]:
        """
        Yield archived messages of a room matching all given filters.

        Messages come in ascending timestamp order, or descending if ``flip``.
        ``last_page_n`` keeps only the ``last_page_n`` most recent matches.
        Returned timestamps are tagged as UTC.

        :raises XMPPError: item-not-found if ``before_id``/``after_id`` is not
            archived in this room, or if some of ``ids`` cannot be found with
            the given constraints. Raised lazily, when iteration starts.
        """
        with self.session() as session:
            q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)
            if start_date is not None:
                q = q.where(ArchivedMessage.timestamp >= start_date)
            if end_date is not None:
                q = q.where(ArchivedMessage.timestamp <= end_date)
            if before_id is not None:
                stamp = self._timestamp_of(session, room_pk, before_id)
                q = q.where(ArchivedMessage.timestamp < stamp)
            if after_id is not None:
                stamp = self._timestamp_of(session, room_pk, after_id)
                q = q.where(ArchivedMessage.timestamp > stamp)
            if ids:
                q = q.filter(ArchivedMessage.stanza_id.in_(ids))
            if sender is not None:
                q = q.where(ArchivedMessage.author_jid == sender)
            if flip:
                q = q.order_by(ArchivedMessage.timestamp.desc())
            else:
                q = q.order_by(ArchivedMessage.timestamp.asc())
            msgs = list(session.execute(q).scalars())
            if ids and len(msgs) != len(ids):
                raise XMPPError(
                    "item-not-found",
                    "One of the requested messages IDs could not be found "
                    "with the given constraints.",
                )
            if last_page_n is not None:
                if flip:
                    msgs = msgs[:last_page_n]
                else:
                    # msgs[-0:] would be the whole list instead of an empty one.
                    msgs = msgs[max(len(msgs) - last_page_n, 0) :]
            for h in msgs:
                yield HistoryMessage(
                    stanza=str(h.stanza), when=h.timestamp.replace(tzinfo=timezone.utc)
                )

    def get_first_and_last(self, room_pk: int) -> list[MamMetadata]:
        """
        Return metadata of the oldest and newest archived messages of a room.

        The list is empty when the room has no archived message.
        """
        r = []
        with self.session() as session:
            first = session.execute(
                select(ArchivedMessage.stanza_id, ArchivedMessage.timestamp)
                .where(ArchivedMessage.room_id == room_pk)
                .order_by(ArchivedMessage.timestamp.asc())
            ).first()
            if first is not None:
                r.append(MamMetadata(*first))
            last = session.execute(
                select(ArchivedMessage.stanza_id, ArchivedMessage.timestamp)
                .where(ArchivedMessage.room_id == room_pk)
                .order_by(ArchivedMessage.timestamp.desc())
            ).first()
            if last is not None:
                r.append(MamMetadata(*last))
        return r
547
+
548
+
549
class MultiStore(EngineMixin):
    """Per-user mapping of one legacy message ID to several XMPP message IDs."""

    def get_xmpp_ids(self, user_pk: int, xmpp_id: str) -> list[str]:
        """
        Return all XMPP IDs sharing a legacy message with ``xmpp_id``.

        Returns an empty list if ``xmpp_id`` is unknown for this user.
        """
        with self.session() as session:
            multi = session.execute(
                select(XmppIdsMulti)
                .where(XmppIdsMulti.xmpp_id == xmpp_id)
                .where(XmppIdsMulti.user_account_id == user_pk)
            ).scalar()
            if multi is None:
                return []
            return [m.xmpp_id for m in multi.legacy_ids_multi.xmpp_ids]

    def set_xmpp_ids(
        self, user_pk: int, legacy_msg_id: str, xmpp_ids: list[str], fail=False
    ) -> None:
        """
        Record that legacy message ``legacy_msg_id`` maps to ``xmpp_ids``.

        Empty strings in ``xmpp_ids`` are skipped. If a mapping already exists
        for this legacy ID, it is deleted together with any row already
        claiming one of the new XMPP IDs, and the insert is retried once
        (``fail`` is set internally for that retry).

        :raises RuntimeError: if the mapping still exists on the retry.
        """
        with self.session() as session:
            existing = session.execute(
                select(LegacyIdsMulti)
                .where(LegacyIdsMulti.user_account_id == user_pk)
                .where(LegacyIdsMulti.legacy_id == legacy_msg_id)
            ).scalar()
            if existing is not None:
                if fail:
                    # A bare ``raise`` here would re-raise whatever exception a
                    # caller happens to be handling, or an opaque RuntimeError.
                    raise RuntimeError(
                        f"Could not reset the multi IDs of legacy message {legacy_msg_id}"
                    )
                log.warning("Resetting multi for %s", legacy_msg_id)
                session.execute(
                    delete(LegacyIdsMulti)
                    .where(LegacyIdsMulti.user_account_id == user_pk)
                    .where(LegacyIdsMulti.legacy_id == legacy_msg_id)
                )
                for i in xmpp_ids:
                    session.execute(
                        delete(XmppIdsMulti)
                        .where(XmppIdsMulti.user_account_id == user_pk)
                        .where(XmppIdsMulti.xmpp_id == i)
                    )
                session.commit()
                self.set_xmpp_ids(user_pk, legacy_msg_id, xmpp_ids, True)
                return

            row = LegacyIdsMulti(
                user_account_id=user_pk,
                legacy_id=legacy_msg_id,
                xmpp_ids=[
                    XmppIdsMulti(user_account_id=user_pk, xmpp_id=i)
                    for i in xmpp_ids
                    if i
                ],
            )
            session.add(row)
            session.commit()

    def get_legacy_id(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        """Return the legacy message ID that ``xmpp_id`` belongs to, or None."""
        with self.session() as session:
            multi = session.execute(
                select(XmppIdsMulti)
                .where(XmppIdsMulti.xmpp_id == xmpp_id)
                .where(XmppIdsMulti.user_account_id == user_pk)
            ).scalar()
            if multi is None:
                return None
            return multi.legacy_ids_multi.legacy_id
611
+
612
+
613
class AttachmentStore(EngineMixin):
    """Cache of uploaded attachments: legacy file ID -> URL, plus SIMS/SFS data."""

    def _column_for_url(self, column, url: str) -> Optional[str]:
        """Return ``column`` of the attachment stored at ``url``, or None."""
        with self.session() as session:
            return session.execute(
                select(column).where(Attachment.url == url)
            ).scalar()

    def _update_for_url(self, url: str, **values) -> None:
        """Set ``values`` on the attachment(s) stored at ``url``."""
        with self.session() as session:
            session.execute(
                update(Attachment).where(Attachment.url == url).values(**values)
            )
            session.commit()

    def get_url(self, legacy_file_id: str) -> Optional[str]:
        """Return the URL uploaded for ``legacy_file_id``, or None."""
        query = select(Attachment.url).where(
            Attachment.legacy_file_id == legacy_file_id
        )
        with self.session() as session:
            return session.execute(query).scalar()

    def set_url(self, user_pk: int, legacy_file_id: str, url: str) -> None:
        """Create or update the user's attachment row for ``legacy_file_id``."""
        with self.session() as session:
            attachment = session.execute(
                select(Attachment)
                .where(Attachment.legacy_file_id == legacy_file_id)
                .where(Attachment.user_account_id == user_pk)
            ).scalar()
            if attachment is not None:
                attachment.url = url
            else:
                session.add(
                    Attachment(
                        legacy_file_id=legacy_file_id,
                        url=url,
                        user_account_id=user_pk,
                    )
                )
            session.commit()

    def get_sims(self, url: str) -> Optional[str]:
        return self._column_for_url(Attachment.sims, url)

    def set_sims(self, url: str, sims: str) -> None:
        self._update_for_url(url, sims=sims)

    def get_sfs(self, url: str) -> Optional[str]:
        return self._column_for_url(Attachment.sfs, url)

    def set_sfs(self, url: str, sfs: str) -> None:
        self._update_for_url(url, sfs=sfs)

    def remove(self, legacy_file_id: str) -> None:
        """Delete every attachment row for ``legacy_file_id``."""
        statement = delete(Attachment).where(
            Attachment.legacy_file_id == legacy_file_id
        )
        with self.session() as session:
            session.execute(statement)
            session.commit()
670
+
671
+
672
class RoomStore(UpdatedMixin):
    """Per-user persistence of group chats (``Room`` rows)."""

    model = Room

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)
        # Subject setters and joined user resources are cleared on startup.
        wipe = update(Room).values(subject_setter_id=None, user_resources=None)
        with self.session() as session:
            session.execute(wipe)
            session.commit()

    @staticmethod
    def _dump_resources(resources) -> Optional[str]:
        """JSON-encode a collection of resources, or None when it is empty."""
        if resources:
            return json.dumps(list(resources))
        return None

    def add(self, user_pk: int, legacy_id: str, jid: JID) -> int:
        """
        Return the pk of the user's room ``legacy_id``, creating it if needed.

        :raises TypeError: if ``jid`` has a resource part.
        """
        if jid.resource:
            raise TypeError
        with self.session() as session:
            room_id = session.execute(
                select(Room.id)
                .where(Room.user_account_id == user_pk)
                .where(Room.legacy_id == legacy_id)
            ).scalar()
            if room_id is None:
                room = Room(jid=jid, user_account_id=user_pk, legacy_id=legacy_id)
                session.add(room)
                session.commit()
                room_id = room.id
            return room_id

    def set_avatar(self, room_pk: int, avatar_pk: int) -> None:
        with self.session() as session:
            session.execute(
                update(Room).where(Room.id == room_pk).values(avatar_id=avatar_pk)
            )
            session.commit()

    def get_avatar_legacy_id(self, room_pk: int) -> Optional[str]:
        """Return the legacy ID of the room's avatar, or None."""
        with self.session() as session:
            room = session.execute(select(Room).where(Room.id == room_pk)).scalar()
            avatar = None if room is None else room.avatar
            return None if avatar is None else avatar.legacy_id

    def get_by_jid(self, user_pk: int, jid: JID) -> Optional[Room]:
        """
        Return the user's room with JID ``jid``, or None.

        :raises TypeError: if ``jid`` has a resource part.
        """
        if jid.resource:
            raise TypeError
        query = (
            select(Room).where(Room.user_account_id == user_pk).where(Room.jid == jid)
        )
        with self.session() as session:
            return session.execute(query).scalar()

    def get_by_legacy_id(self, user_pk: int, legacy_id: str) -> Optional[Room]:
        query = (
            select(Room)
            .where(Room.user_account_id == user_pk)
            .where(Room.legacy_id == legacy_id)
        )
        with self.session() as session:
            return session.execute(query).scalar()

    def update(self, room: "LegacyMUC"):
        """Write the persistent attributes of ``room`` and flag the row as updated."""
        from slidge.contact import LegacyContact

        with self.session() as session:
            setter = room.subject_setter
            # Only a non-system participant has a pk that can be referenced.
            if (
                setter is None
                or isinstance(setter, (str, LegacyContact))
                or setter.is_system
            ):
                setter_pk = None
            else:
                setter_pk = setter.pk

            session.execute(
                update(Room)
                .where(Room.id == room.pk)
                .values(
                    updated=True,
                    extra_attributes=room.serialize_extra_attributes(),
                    name=room.name,
                    description=room.description,
                    user_resources=self._dump_resources(room._user_resources),
                    muc_type=room.type,
                    subject=room.subject,
                    subject_date=room.subject_date,
                    subject_setter_id=setter_pk,
                    participants_filled=room._participants_filled,
                    n_participants=room._n_participants,
                )
            )
            session.commit()

    def update_subject_date(
        self, room_pk: int, subject_date: Optional[datetime]
    ) -> None:
        with self.session() as session:
            session.execute(
                update(Room).where(Room.id == room_pk).values(subject_date=subject_date)
            )
            session.commit()

    def update_subject(self, room_pk: int, subject: Optional[str]) -> None:
        with self.session() as session:
            session.execute(
                update(Room).where(Room.id == room_pk).values(subject=subject)
            )
            session.commit()

    def update_description(self, room_pk: int, desc: Optional[str]) -> None:
        with self.session() as session:
            session.execute(
                update(Room).where(Room.id == room_pk).values(description=desc)
            )
            session.commit()

    def update_n_participants(self, room_pk: int, n: Optional[int]) -> None:
        with self.session() as session:
            session.execute(
                update(Room).where(Room.id == room_pk).values(n_participants=n)
            )
            session.commit()

    def delete(self, room_pk: int) -> None:
        """Delete a room and all of its participants."""
        with self.session() as session:
            session.execute(delete(Room).where(Room.id == room_pk))
            session.execute(delete(Participant).where(Participant.room_id == room_pk))
            session.commit()

    def set_resource(self, room_pk: int, resources: set[str]) -> None:
        """Store the set of user resources joined to the room (None if empty)."""
        encoded = self._dump_resources(resources)
        with self.session() as session:
            session.execute(
                update(Room).where(Room.id == room_pk).values(user_resources=encoded)
            )
            session.commit()

    def nickname_is_available(self, room_pk: int, nickname: str) -> bool:
        """True if no participant of the room currently uses ``nickname``."""
        with self.session() as session:
            taken = session.execute(
                select(Participant)
                .where(Participant.room_id == room_pk)
                .where(Participant.nickname == nickname)
            ).scalar()
            return taken is None

    def get_all(self, user_pk: int) -> Iterator[Room]:
        """Yield all rooms of a user; the session stays open while iterating."""
        with self.session() as session:
            yield from session.execute(
                select(Room).where(Room.user_account_id == user_pk)
            ).scalars()
834
+
835
+
836
class ParticipantStore(EngineMixin):
    """Persistence of group chat participants and their hats."""

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)
        # Participants and hats are not kept across restarts.
        with self.session() as session:
            session.execute(delete(Participant))
            session.execute(delete(Hat))
            session.commit()

    def add(self, room_pk: int, nickname: str) -> int:
        """Return the pk of the room participant ``nickname``, creating it if needed."""
        with self.session() as session:
            existing = session.execute(
                select(Participant.id)
                .where(Participant.room_id == room_pk)
                .where(Participant.nickname == nickname)
            ).scalar()
            if existing is not None:
                return existing
            participant = Participant(room_id=room_pk, nickname=nickname)
            session.add(participant)
            session.commit()
            return participant.id

    def get_by_nickname(self, room_pk: int, nickname: str) -> Optional[Participant]:
        with self.session() as session:
            return session.execute(
                select(Participant)
                .where(Participant.room_id == room_pk)
                .where(Participant.nickname == nickname)
            ).scalar()

    def get_by_resource(self, room_pk: int, resource: str) -> Optional[Participant]:
        with self.session() as session:
            return session.execute(
                select(Participant)
                .where(Participant.room_id == room_pk)
                .where(Participant.resource == resource)
            ).scalar()

    def get_by_contact(self, room_pk: int, contact_pk: int) -> Optional[Participant]:
        with self.session() as session:
            return session.execute(
                select(Participant)
                .where(Participant.room_id == room_pk)
                .where(Participant.contact_id == contact_pk)
            ).scalar()

    def get_all(self, room_pk: int, user_included=True) -> Iterator[Participant]:
        """
        Yield the participants of a room, optionally excluding the user's own
        participant. The session stays open while the caller iterates.
        """
        with self.session() as session:
            q = select(Participant).where(Participant.room_id == room_pk)
            if not user_included:
                q = q.where(~Participant.is_user)
            yield from session.execute(q).scalars()

    def get_for_contact(self, contact_pk: int) -> Iterator[Participant]:
        """Yield every participant linked to the contact, across all rooms."""
        with self.session() as session:
            yield from session.execute(
                select(Participant).where(Participant.contact_id == contact_pk)
            ).scalars()

    def update(self, participant: "LegacyParticipant") -> None:
        """Write the persistent attributes of ``participant`` (hats excluded)."""
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant.pk)
                .values(
                    resource=participant.jid.resource,
                    affiliation=participant.affiliation,
                    role=participant.role,
                    presence_sent=participant._presence_sent,  # type:ignore
                    is_user=participant.is_user,
                    contact_id=(
                        None
                        if participant.contact is None
                        else participant.contact.contact_pk
                    ),
                )
            )
            session.commit()

    def add_hat(self, uri: str, title: str) -> Hat:
        """Return the hat with this URI and title, creating it if needed."""
        with self.session() as session:
            existing = session.execute(
                select(Hat).where(Hat.uri == uri).where(Hat.title == title)
            ).scalar()
            if existing is not None:
                return existing
            hat = Hat(uri=uri, title=title)
            session.add(hat)
            session.commit()
            return hat

    def set_presence_sent(self, participant_pk: int) -> None:
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant_pk)
                .values(presence_sent=True)
            )
            session.commit()

    def set_affiliation(self, participant_pk: int, affiliation: MucAffiliation) -> None:
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant_pk)
                .values(affiliation=affiliation)
            )
            session.commit()

    def set_role(self, participant_pk: int, role: MucRole) -> None:
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant_pk)
                .values(role=role)
            )
            session.commit()

    def set_hats(self, participant_pk: int, hats: list[HatTuple]) -> None:
        """
        Replace the hats of a participant with ``hats`` (duplicates ignored).

        :raises ValueError: if the participant does not exist.
        """
        with self.session() as session:
            part = session.execute(
                select(Participant).where(Participant.id == participant_pk)
            ).scalar()
            if part is None:
                raise ValueError
            part.hats.clear()
            for h in hats:
                # add_hat() runs inside this block and reuses this session.
                hat = self.add_hat(*h)
                if hat in part.hats:
                    continue
                part.hats.append(hat)
            session.commit()

    def delete(self, participant_pk: int) -> None:
        """Delete a participant."""
        with self.session() as session:
            session.execute(delete(Participant).where(Participant.id == participant_pk))
            # Without a commit the deletion is rolled back when the session closes.
            session.commit()
973
+
974
+
975
# Session shared by nested EngineMixin.session() calls; None while no session
# block is open.
_session: Optional[Session] = None
# Module logger (used by MAMStore.add_message and MultiStore.set_xmpp_ids).
log = logging.getLogger(__name__)