slidge 0.1.3__py3-none-any.whl → 0.2.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. slidge/__init__.py +3 -5
  2. slidge/__main__.py +2 -196
  3. slidge/__version__.py +5 -0
  4. slidge/command/adhoc.py +8 -1
  5. slidge/command/admin.py +6 -7
  6. slidge/command/base.py +1 -2
  7. slidge/command/register.py +32 -16
  8. slidge/command/user.py +85 -6
  9. slidge/contact/contact.py +165 -49
  10. slidge/contact/roster.py +122 -47
  11. slidge/core/config.py +14 -11
  12. slidge/core/gateway/base.py +148 -36
  13. slidge/core/gateway/caps.py +7 -5
  14. slidge/core/gateway/disco.py +2 -4
  15. slidge/core/gateway/mam.py +1 -4
  16. slidge/core/gateway/muc_admin.py +1 -1
  17. slidge/core/gateway/ping.py +2 -3
  18. slidge/core/gateway/presence.py +1 -1
  19. slidge/core/gateway/registration.py +32 -21
  20. slidge/core/gateway/search.py +3 -5
  21. slidge/core/gateway/session_dispatcher.py +120 -57
  22. slidge/core/gateway/vcard_temp.py +7 -5
  23. slidge/core/mixins/__init__.py +11 -1
  24. slidge/core/mixins/attachment.py +32 -14
  25. slidge/core/mixins/avatar.py +90 -25
  26. slidge/core/mixins/base.py +8 -2
  27. slidge/core/mixins/db.py +18 -0
  28. slidge/core/mixins/disco.py +0 -10
  29. slidge/core/mixins/message.py +18 -8
  30. slidge/core/mixins/message_maker.py +17 -9
  31. slidge/core/mixins/presence.py +17 -4
  32. slidge/core/pubsub.py +54 -220
  33. slidge/core/session.py +69 -34
  34. slidge/db/__init__.py +4 -0
  35. slidge/db/alembic/env.py +64 -0
  36. slidge/db/alembic/script.py.mako +26 -0
  37. slidge/db/alembic/versions/09f27f098baa_add_missing_attributes_in_room.py +36 -0
  38. slidge/db/alembic/versions/2461390c0af2_store_contacts_caps_verstring_in_db.py +36 -0
  39. slidge/db/alembic/versions/29f5280c61aa_store_subject_setter_in_room.py +37 -0
  40. slidge/db/alembic/versions/2b1f45ab7379_store_room_subject_setter_by_nickname.py +41 -0
  41. slidge/db/alembic/versions/82a4af84b679_add_muc_history_filled.py +48 -0
  42. slidge/db/alembic/versions/8d2ced764698_rely_on_db_to_store_contacts_rooms_and_.py +133 -0
  43. slidge/db/alembic/versions/aa9d82a7f6ef_db_creation.py +85 -0
  44. slidge/db/alembic/versions/b33993e87db3_move_everything_to_persistent_db.py +214 -0
  45. slidge/db/alembic/versions/b64b1a793483_add_source_and_legacy_id_for_archived_.py +48 -0
  46. slidge/db/alembic/versions/c4a8ec35a0e8_per_room_user_nick.py +34 -0
  47. slidge/db/alembic/versions/e91195719c2c_store_users_avatars_persistently.py +26 -0
  48. slidge/db/avatar.py +235 -0
  49. slidge/db/meta.py +65 -0
  50. slidge/db/models.py +375 -0
  51. slidge/db/store.py +1078 -0
  52. slidge/group/archive.py +58 -14
  53. slidge/group/bookmarks.py +72 -57
  54. slidge/group/participant.py +87 -28
  55. slidge/group/room.py +369 -211
  56. slidge/main.py +201 -0
  57. slidge/migration.py +30 -0
  58. slidge/slixfix/__init__.py +35 -2
  59. slidge/slixfix/roster.py +11 -4
  60. slidge/slixfix/xep_0292/vcard4.py +3 -0
  61. slidge/util/archive_msg.py +2 -1
  62. slidge/util/db.py +1 -47
  63. slidge/util/test.py +71 -4
  64. slidge/util/types.py +29 -4
  65. slidge/util/util.py +22 -0
  66. {slidge-0.1.3.dist-info → slidge-0.2.0a1.dist-info}/METADATA +4 -4
  67. slidge-0.2.0a1.dist-info/RECORD +114 -0
  68. slidge/core/cache.py +0 -183
  69. slidge/util/schema.sql +0 -126
  70. slidge/util/sql.py +0 -508
  71. slidge-0.1.3.dist-info/RECORD +0 -96
  72. {slidge-0.1.3.dist-info → slidge-0.2.0a1.dist-info}/LICENSE +0 -0
  73. {slidge-0.1.3.dist-info → slidge-0.2.0a1.dist-info}/WHEEL +0 -0
  74. {slidge-0.1.3.dist-info → slidge-0.2.0a1.dist-info}/entry_points.txt +0 -0
slidge/db/store.py ADDED
@@ -0,0 +1,1078 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import logging
5
+ from contextlib import contextmanager
6
+ from datetime import datetime, timedelta, timezone
7
+ from typing import TYPE_CHECKING, Collection, Iterator, Optional, Type
8
+
9
+ from slixmpp import JID, Iq, Message, Presence
10
+ from slixmpp.exceptions import XMPPError
11
+ from sqlalchemy import Engine, delete, select, update
12
+ from sqlalchemy.orm import Session, attributes
13
+ from sqlalchemy.sql.functions import count
14
+
15
+ from ..util.archive_msg import HistoryMessage
16
+ from ..util.types import URL, CachedPresence
17
+ from ..util.types import Hat as HatTuple
18
+ from ..util.types import MamMetadata, MucAffiliation, MucRole
19
+ from .meta import Base
20
+ from .models import (
21
+ ArchivedMessage,
22
+ ArchivedMessageSource,
23
+ Attachment,
24
+ Avatar,
25
+ Contact,
26
+ ContactSent,
27
+ GatewayUser,
28
+ Hat,
29
+ LegacyIdsMulti,
30
+ Participant,
31
+ Room,
32
+ XmppIdsMulti,
33
+ XmppToLegacyEnum,
34
+ XmppToLegacyIds,
35
+ )
36
+
37
+ if TYPE_CHECKING:
38
+ from ..contact.contact import LegacyContact
39
+ from ..group.participant import LegacyParticipant
40
+ from ..group.room import LegacyMUC
41
+
42
+
43
class EngineMixin:
    """Base class for stores: keep an engine and hand out a shared session.

    At most one SQLAlchemy session is active at a time. It lives in the
    module-level ``_session`` global, so nested ``with self.session()`` blocks
    (even across different stores) reuse the outermost session.
    """

    def __init__(self, engine: Engine):
        self._engine = engine

    @contextmanager
    def session(self, **session_kwargs) -> Iterator[Session]:
        """Yield the active session, creating one only if none is active.

        :param session_kwargs: forwarded to :class:`sqlalchemy.orm.Session`
            when a new session is created; ignored when an outer session is
            reused.
        """
        global _session
        outer = _session
        if outer is None:
            with Session(self._engine, **session_kwargs) as fresh:
                _session = fresh
                try:
                    yield fresh
                finally:
                    # Only the context that created the session clears it.
                    _session = None
        else:
            yield outer
59
+
60
+
61
class UpdatedMixin(EngineMixin):
    """Store mixin for tables whose rows carry an ``updated`` boolean column.

    On construction, every row whose ``updated`` flag is false is deleted.
    The flag of every remaining row is then reset to false.
    """

    # Mapped class handled by the concrete store; subclasses override it.
    model: Type[Base] = NotImplemented

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)
        table = self.model
        with self.session() as session:
            stale_rows = delete(table).where(~table.updated)  # type:ignore
            session.execute(stale_rows)
            session.execute(update(table).values(updated=False))  # type:ignore
            session.commit()

    def get_by_pk(self, pk: int) -> Optional[Base]:
        """Return the row of :attr:`model` whose ``id`` is *pk*, or ``None``."""
        table = self.model
        query = select(table).where(table.id == pk)  # type:ignore
        with self.session() as session:
            return session.execute(query).scalar()
78
+
79
+
80
class SlidgeStore(EngineMixin):
    """Facade bundling every table-specific store around one engine."""

    def __init__(self, engine: Engine):
        super().__init__(engine)
        # Construction order is kept as-is: some stores run startup queries
        # in their constructors.
        self.users: UserStore = UserStore(engine)
        self.avatars: AvatarStore = AvatarStore(engine)
        self.contacts: ContactStore = ContactStore(engine)
        self.mam: MAMStore = MAMStore(engine)
        self.multi: MultiStore = MultiStore(engine)
        self.attachments: AttachmentStore = AttachmentStore(engine)
        self.rooms: RoomStore = RoomStore(engine)
        self.sent: SentStore = SentStore(engine)
        self.participants: ParticipantStore = ParticipantStore(engine)
92
+
93
+
94
class UserStore(EngineMixin):
    """CRUD helpers for :class:`GatewayUser` rows, keyed by bare JID."""

    def new(self, jid: JID, legacy_module_data: dict) -> GatewayUser:
        """Return the user for *jid*, creating the row if it does not exist.

        The resource part of *jid* is dropped. If the user already exists,
        *legacy_module_data* is ignored and the stored row is returned.
        """
        if jid.resource:
            jid = JID(jid.bare)
        # expire_on_commit=False (honoured only when this call opens the
        # session) keeps attributes loaded after commit.
        with self.session(expire_on_commit=False) as session:
            user = session.execute(
                select(GatewayUser).where(GatewayUser.jid == jid)
            ).scalar()
            if user is not None:
                return user
            user = GatewayUser(jid=jid, legacy_module_data=legacy_module_data)
            session.add(user)
            session.commit()
            return user

    def update(self, user: GatewayUser):
        """Persist changes made to *user*, including in-place JSON mutations."""
        # Mutating a JSON column in place is not detected by SQLAlchemy, so
        # mark the columns dirty explicitly.
        # https://github.com/sqlalchemy/sqlalchemy/discussions/6473
        attributes.flag_modified(user, "legacy_module_data")
        attributes.flag_modified(user, "preferences")
        with self.session() as session:
            session.add(user)
            session.commit()

    def get_all(self) -> Iterator[GatewayUser]:
        """Yield every registered user."""
        with self.session() as session:
            yield from session.execute(select(GatewayUser)).scalars()

    def get(self, jid: JID) -> Optional[GatewayUser]:
        """Return the user whose bare JID matches *jid*, or ``None``."""
        with self.session() as session:
            return session.execute(
                select(GatewayUser).where(GatewayUser.jid == jid.bare)
            ).scalar()

    def get_by_stanza(self, stanza: Iq | Message | Presence) -> Optional[GatewayUser]:
        """Return the user who sent *stanza*, or ``None``."""
        return self.get(stanza.get_from())

    def delete(self, jid: JID) -> None:
        """Delete the user for *jid*; do nothing if no such user exists."""
        with self.session() as session:
            user = self.get(jid)
            if user is None:
                # session.delete(None) would raise an UnmappedInstanceError.
                return
            session.delete(user)
            session.commit()
134
+
135
+
136
class AvatarStore(EngineMixin):
    """Lookup and deletion helpers for :class:`Avatar` rows."""

    def get_by_url(self, url: URL | str) -> Optional[Avatar]:
        """Return the avatar whose ``url`` column equals *url*, or ``None``."""
        query = select(Avatar).where(Avatar.url == url)
        with self.session() as session:
            return session.execute(query).scalar()

    def get_by_legacy_id(self, legacy_id: str) -> Optional[Avatar]:
        """Return the avatar whose ``legacy_id`` equals *legacy_id*, or ``None``."""
        query = select(Avatar).where(Avatar.legacy_id == legacy_id)
        with self.session() as session:
            return session.execute(query).scalar()

    def get_by_pk(self, pk: int) -> Optional[Avatar]:
        """Return the avatar with primary key *pk*, or ``None``."""
        query = select(Avatar).where(Avatar.id == pk)
        with self.session() as session:
            return session.execute(query).scalar()

    def delete_by_pk(self, pk: int):
        """Delete the avatar with primary key *pk* (no-op if absent)."""
        statement = delete(Avatar).where(Avatar.id == pk)
        with self.session() as session:
            session.execute(statement)
            session.commit()

    def get_all(self) -> Iterator[Avatar]:
        """Yield every stored avatar."""
        with self.session() as session:
            yield from session.execute(select(Avatar)).scalars()
159
+
160
+
161
class SentStore(EngineMixin):
    """Per-user mapping between legacy IDs and XMPP IDs.

    Each mapping is tagged with an :class:`XmppToLegacyEnum` kind: direct
    messages, group chat messages or threads.
    """

    def _insert(
        self, user_pk: int, legacy_id: str, xmpp_id: str, kind: XmppToLegacyEnum
    ) -> None:
        """Store one legacy <-> XMPP ID pair of the given *kind*."""
        with self.session() as session:
            session.add(
                XmppToLegacyIds(
                    user_account_id=user_pk,
                    legacy_id=legacy_id,
                    xmpp_id=xmpp_id,
                    type=kind,
                )
            )
            session.commit()

    def _xmpp_from_legacy(
        self, user_pk: int, legacy_id: str, kind: XmppToLegacyEnum
    ) -> Optional[str]:
        """Return the XMPP ID mapped to *legacy_id* for this *kind*, if any."""
        query = (
            select(XmppToLegacyIds.xmpp_id)
            .where(XmppToLegacyIds.user_account_id == user_pk)
            .where(XmppToLegacyIds.legacy_id == legacy_id)
            .where(XmppToLegacyIds.type == kind)
        )
        with self.session() as session:
            return session.execute(query).scalar()

    def _legacy_from_xmpp(
        self, user_pk: int, xmpp_id: str, kind: XmppToLegacyEnum
    ) -> Optional[str]:
        """Return the legacy ID mapped to *xmpp_id* for this *kind*, if any."""
        query = (
            select(XmppToLegacyIds.legacy_id)
            .where(XmppToLegacyIds.user_account_id == user_pk)
            .where(XmppToLegacyIds.xmpp_id == xmpp_id)
            .where(XmppToLegacyIds.type == kind)
        )
        with self.session() as session:
            return session.execute(query).scalar()

    def set_message(self, user_pk: int, legacy_id: str, xmpp_id: str) -> None:
        """Record a direct-message ID pair."""
        self._insert(user_pk, legacy_id, xmpp_id, XmppToLegacyEnum.DM)

    def get_xmpp_id(self, user_pk: int, legacy_id: str) -> Optional[str]:
        """Return the XMPP ID of a direct message from its legacy ID."""
        return self._xmpp_from_legacy(user_pk, legacy_id, XmppToLegacyEnum.DM)

    def get_legacy_id(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        """Return the legacy ID of a direct message from its XMPP ID."""
        return self._legacy_from_xmpp(user_pk, xmpp_id, XmppToLegacyEnum.DM)

    def set_group_message(self, user_pk: int, legacy_id: str, xmpp_id: str) -> None:
        """Record a group chat message ID pair."""
        self._insert(user_pk, legacy_id, xmpp_id, XmppToLegacyEnum.GROUP_CHAT)

    def get_group_xmpp_id(self, user_pk: int, legacy_id: str) -> Optional[str]:
        """Return the XMPP ID of a group chat message from its legacy ID."""
        return self._xmpp_from_legacy(
            user_pk, legacy_id, XmppToLegacyEnum.GROUP_CHAT
        )

    def get_group_legacy_id(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        """Return the legacy ID of a group chat message from its XMPP ID."""
        return self._legacy_from_xmpp(user_pk, xmpp_id, XmppToLegacyEnum.GROUP_CHAT)

    def set_thread(self, user_pk: int, legacy_id: str, xmpp_id: str) -> None:
        """Record a thread ID pair."""
        self._insert(user_pk, legacy_id, xmpp_id, XmppToLegacyEnum.THREAD)

    def get_legacy_thread(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        """Return the legacy thread ID matching an XMPP thread ID."""
        return self._legacy_from_xmpp(user_pk, xmpp_id, XmppToLegacyEnum.THREAD)

    def get_xmpp_thread(self, user_pk: int, legacy_id: str) -> Optional[str]:
        """Return the XMPP thread ID matching a legacy thread ID."""
        return self._xmpp_from_legacy(user_pk, legacy_id, XmppToLegacyEnum.THREAD)

    def was_sent_by_user(self, user_pk: int, legacy_id: str) -> bool:
        """Tell whether *legacy_id* has any mapping for this user.

        Unlike the getters above, the mapping kind is not filtered on.
        """
        query = (
            select(XmppToLegacyIds.legacy_id)
            .where(XmppToLegacyIds.user_account_id == user_pk)
            .where(XmppToLegacyIds.legacy_id == legacy_id)
        )
        with self.session() as session:
            found = session.execute(query).scalar()
        return found is not None
259
+
260
+
261
class ContactStore(UpdatedMixin):
    """Persistence layer for legacy contacts (:class:`Contact` rows)."""

    model = Contact

    def __init__(self, *a, **k):
        super().__init__(*a, **k)
        # Invalidate every cached presence at startup.
        with self.session() as session:
            session.execute(update(Contact).values(cached_presence=False))
            session.commit()

    def get_all(self, user_pk: int) -> Iterator[Contact]:
        """Yield every contact belonging to user *user_pk*."""
        with self.session() as session:
            yield from session.execute(
                select(Contact).where(Contact.user_account_id == user_pk)
            ).scalars()

    def get_by_jid(self, user_pk: int, jid: JID) -> Optional[Contact]:
        """Return the contact of *user_pk* whose JID is the bare part of *jid*."""
        with self.session() as session:
            return session.execute(
                select(Contact)
                .where(Contact.jid == jid.bare)
                .where(Contact.user_account_id == user_pk)
            ).scalar()

    def get_by_legacy_id(self, user_pk: int, legacy_id: str) -> Optional[Contact]:
        """Return the contact of *user_pk* with this legacy ID, or ``None``."""
        with self.session() as session:
            return session.execute(
                select(Contact)
                .where(Contact.legacy_id == legacy_id)
                .where(Contact.user_account_id == user_pk)
            ).scalar()

    def update_nick(self, contact_pk: int, nick: Optional[str]) -> None:
        """Set the stored nickname of a contact."""
        with self.session() as session:
            session.execute(
                update(Contact).where(Contact.id == contact_pk).values(nick=nick)
            )
            session.commit()

    def get_presence(self, contact_pk: int) -> Optional[CachedPresence]:
        """Return the cached presence of a contact.

        Returns ``None`` if the contact does not exist or no presence is cached.
        """
        with self.session() as session:
            presence = session.execute(
                select(
                    Contact.last_seen,
                    Contact.ptype,
                    Contact.pstatus,
                    Contact.pshow,
                    Contact.cached_presence,
                ).where(Contact.id == contact_pk)
            ).first()
            # The last selected column is the "cached_presence" flag.
            if presence is None or not presence[-1]:
                return None
            return CachedPresence(*presence[:-1])

    def set_presence(self, contact_pk: int, presence: CachedPresence) -> None:
        """Store *presence* as the contact's cached presence."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(**presence._asdict(), cached_presence=True)
            )
            session.commit()

    def reset_presence(self, contact_pk: int):
        """Clear every cached presence field of a contact."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(
                    last_seen=None,
                    ptype=None,
                    pstatus=None,
                    pshow=None,
                    cached_presence=False,
                )
            )
            session.commit()

    def set_avatar(self, contact_pk: int, avatar_pk: Optional[int]):
        """Point a contact at an avatar row (or at none)."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(avatar_id=avatar_pk)
            )
            session.commit()

    def get_avatar_legacy_id(self, contact_pk: int) -> Optional[str]:
        """Return the legacy ID of the contact's avatar, if any."""
        with self.session() as session:
            contact = session.execute(
                select(Contact).where(Contact.id == contact_pk)
            ).scalar()
            if contact is None or contact.avatar is None:
                return None
            return contact.avatar.legacy_id

    def update(self, contact: "LegacyContact", commit=True) -> int:
        """Insert or update the row mirroring *contact*; return its primary key.

        :param commit: when False, the changes are only flushed and the caller
            is responsible for committing the surrounding session.
        """
        with self.session() as session:
            if contact.contact_pk is None:
                if contact.cached_presence is not None:
                    presence_kwargs = contact.cached_presence._asdict()
                    presence_kwargs["cached_presence"] = True
                else:
                    presence_kwargs = {}
                row = Contact(
                    jid=contact.jid.bare,
                    legacy_id=str(contact.legacy_id),
                    user_account_id=contact.user_pk,
                    **presence_kwargs,
                )
            else:
                row = (
                    session.query(Contact)
                    .filter(Contact.id == contact.contact_pk)
                    .one()
                )
            row.nick = contact.name
            row.is_friend = contact.is_friend
            row.added_to_roster = contact.added_to_roster
            row.updated = True
            row.extra_attributes = contact.serialize_extra_attributes()
            row.caps_ver = contact._caps_ver
            session.add(row)
            if commit:
                session.commit()
            else:
                # A pending new row has no primary key until it is flushed;
                # without this, row.id would be None.
                session.flush()
            return row.id

    def add_to_sent(self, contact_pk: int, msg_id: str) -> None:
        """Append *msg_id* to the contact's list of sent message IDs."""
        with self.session() as session:
            new = ContactSent(contact_id=contact_pk, msg_id=msg_id)
            session.add(new)
            session.commit()

    def pop_sent_up_to(self, contact_pk: int, msg_id: str) -> list[str]:
        """Remove and return the contact's sent message IDs up to *msg_id*.

        IDs are returned in ``ContactSent.id`` order, up to and including
        *msg_id*. If *msg_id* is not stored, every sent ID is popped.
        """
        result: list[str] = []
        to_del: list[int] = []
        with self.session() as session:
            rows = session.execute(
                select(ContactSent)
                .where(ContactSent.contact_id == contact_pk)
                .order_by(ContactSent.id)
            ).scalars().all()
            for row in rows:
                to_del.append(row.id)
                result.append(row.msg_id)
                if row.msg_id == msg_id:
                    break
            if to_del:
                session.execute(delete(ContactSent).where(ContactSent.id.in_(to_del)))
                # Commit so the deletions survive the session being closed.
                session.commit()
        return result

    def set_friend(self, contact_pk: int, is_friend: bool) -> None:
        """Set the ``is_friend`` flag of a contact."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(is_friend=is_friend)
            )
            session.commit()

    def set_added_to_roster(self, contact_pk: int, value: bool) -> None:
        """Set the ``added_to_roster`` flag of a contact."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(added_to_roster=value)
            )
            session.commit()

    def delete(self, contact_pk: int) -> None:
        """Delete a contact row."""
        with self.session() as session:
            session.execute(delete(Contact).where(Contact.id == contact_pk))
            session.commit()
432
+
433
+
434
class MAMStore(EngineMixin):
    """Message archive storage for rooms (:class:`ArchivedMessage` rows)."""

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)
        # At startup, every archived message is relabelled as BACKFILL.
        with self.session() as session:
            session.execute(
                update(ArchivedMessage).values(source=ArchivedMessageSource.BACKFILL)
            )
            session.commit()

    def nuke_older_than(self, days: int) -> None:
        """Delete archived messages whose timestamp is older than *days* days."""
        # get_messages() treats stored timestamps as UTC, so the cutoff must be
        # naive UTC too; datetime.now() would be local time and shift the
        # cutoff by the server's UTC offset.
        cutoff = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(
            days=days
        )
        with self.session() as session:
            session.execute(
                delete(ArchivedMessage).where(ArchivedMessage.timestamp < cutoff)
            )
            session.commit()

    def add_message(
        self,
        room_pk: int,
        message: HistoryMessage,
        archive_only: bool,
        legacy_msg_id: str | None,
    ) -> None:
        """Archive *message* in a room, replacing a message with the same stanza ID.

        :param archive_only: True to tag the message as BACKFILL, False for LIVE.
        """
        with self.session() as session:
            source = (
                ArchivedMessageSource.BACKFILL
                if archive_only
                else ArchivedMessageSource.LIVE
            )
            existing = session.execute(
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.stanza_id == message.id)
            ).scalar()
            if existing is not None:
                log.debug("Updating message %s in room %s", message.id, room_pk)
                existing.timestamp = message.when
                existing.stanza = str(message.stanza)
                existing.author_jid = message.stanza.get_from()
                existing.source = source
                existing.legacy_id = legacy_msg_id
                session.add(existing)
                session.commit()
                return
            mam_msg = ArchivedMessage(
                stanza_id=message.id,
                timestamp=message.when,
                stanza=str(message.stanza),
                author_jid=message.stanza.get_from(),
                room_id=room_pk,
                source=source,
                legacy_id=legacy_msg_id,
            )
            session.add(mam_msg)
            session.commit()

    @staticmethod
    def _timestamp_of(session: Session, room_pk: int, stanza_id: str) -> datetime:
        """Return the timestamp of message *stanza_id* in room *room_pk*.

        :raises XMPPError: item-not-found if the room has no such message.
        """
        # Stanza IDs are only unique per room (see add_message), so the
        # lookup must be scoped to the room being queried.
        stamp = session.execute(
            select(ArchivedMessage.timestamp)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.stanza_id == stanza_id)
        ).scalar()
        if stamp is None:
            raise XMPPError(
                "item-not-found",
                f"Message {stanza_id} not found",
            )
        return stamp

    def get_messages(
        self,
        room_pk: int,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        before_id: Optional[str] = None,
        after_id: Optional[str] = None,
        ids: Collection[str] = (),
        last_page_n: Optional[int] = None,
        sender: Optional[str] = None,
        flip=False,
    ) -> Iterator[HistoryMessage]:
        """Yield the room's archived messages that match the given filters.

        :param before_id: only messages strictly older than this stanza ID
        :param after_id: only messages strictly newer than this stanza ID
        :param ids: only these stanza IDs; every one of them must be found
        :param last_page_n: keep only the *last_page_n* most recent matches
        :param sender: only messages whose author JID equals this value
        :param flip: yield newest first instead of oldest first
        :raises XMPPError: item-not-found if *before_id* or *after_id* is not
            archived in this room, or if some of *ids* cannot be found
        """
        with self.session() as session:
            q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)
            if start_date is not None:
                q = q.where(ArchivedMessage.timestamp >= start_date)
            if end_date is not None:
                q = q.where(ArchivedMessage.timestamp <= end_date)
            if before_id is not None:
                stamp = self._timestamp_of(session, room_pk, before_id)
                q = q.where(ArchivedMessage.timestamp < stamp)
            if after_id is not None:
                stamp = self._timestamp_of(session, room_pk, after_id)
                q = q.where(ArchivedMessage.timestamp > stamp)
            if ids:
                q = q.filter(ArchivedMessage.stanza_id.in_(ids))
            if sender is not None:
                q = q.where(ArchivedMessage.author_jid == sender)
            if flip:
                q = q.order_by(ArchivedMessage.timestamp.desc())
            else:
                q = q.order_by(ArchivedMessage.timestamp.asc())
            msgs = list(session.execute(q).scalars())
            if ids and len(msgs) != len(ids):
                raise XMPPError(
                    "item-not-found",
                    "One of the requested messages IDs could not be found "
                    "with the given constraints.",
                )
            if last_page_n is not None:
                # Both slices keep the most recent matches, given the order.
                if flip:
                    msgs = msgs[:last_page_n]
                else:
                    msgs = msgs[-last_page_n:]
            for h in msgs:
                yield HistoryMessage(
                    stanza=str(h.stanza), when=h.timestamp.replace(tzinfo=timezone.utc)
                )

    def get_first(self, room_pk: int, with_legacy_id=False) -> ArchivedMessage | None:
        """Return the oldest archived message of a room.

        :param with_legacy_id: only consider messages that have a legacy ID
        """
        with self.session() as session:
            q = (
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .order_by(ArchivedMessage.timestamp.asc())
            )
            if with_legacy_id:
                q = q.filter(ArchivedMessage.legacy_id.isnot(None))
            return session.execute(q).scalar()

    def get_last(
        self, room_pk: int, source: ArchivedMessageSource | None = None
    ) -> ArchivedMessage | None:
        """Return the newest archived message of a room, optionally by source."""
        with self.session() as session:
            q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)

            if source is not None:
                q = q.where(ArchivedMessage.source == source)

            return session.execute(
                q.order_by(ArchivedMessage.timestamp.desc())
            ).scalar()

    def get_first_and_last(self, room_pk: int) -> list[MamMetadata]:
        """Return metadata of the oldest and newest messages (0, 1 or 2 items)."""
        r = []
        # One outer session shared by both nested lookups.
        with self.session():
            first = self.get_first(room_pk)
            if first is not None:
                r.append(MamMetadata(first.stanza_id, first.timestamp))
            last = self.get_last(room_pk)
            if last is not None:
                r.append(MamMetadata(last.stanza_id, last.timestamp))
        return r

    def get_most_recent_with_legacy_id(
        self, room_pk: int, source: ArchivedMessageSource | None = None
    ) -> ArchivedMessage | None:
        """Return the newest message of a room that has a legacy ID."""
        with self.session() as session:
            q = (
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.legacy_id.isnot(None))
            )
            if source is not None:
                q = q.where(ArchivedMessage.source == source)
            return session.execute(
                q.order_by(ArchivedMessage.timestamp.desc())
            ).scalar()

    def get_least_recent_with_legacy_id_after(
        self, room_pk: int, after_id: str, source=ArchivedMessageSource.LIVE
    ) -> ArchivedMessage | None:
        """Return the oldest message with a legacy ID newer than *after_id*.

        *after_id* is a legacy ID looked up in the same room. Returns ``None``
        if it is not archived there or no later message matches.
        """
        with self.session() as session:
            after_timestamp = (
                session.query(ArchivedMessage.timestamp)
                .filter(ArchivedMessage.room_id == room_pk)
                .filter(ArchivedMessage.legacy_id == after_id)
                .scalar()
            )
            if after_timestamp is None:
                return None
            q = (
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.legacy_id.isnot(None))
                .where(ArchivedMessage.source == source)
                .where(ArchivedMessage.timestamp > after_timestamp)
            )
            return session.execute(q.order_by(ArchivedMessage.timestamp.asc())).scalar()

    def get_by_legacy_id(self, room_pk: int, legacy_id: str) -> ArchivedMessage | None:
        """Return an archived message of a room by its legacy ID, or ``None``."""
        with self.session() as session:
            return (
                session.query(ArchivedMessage)
                .filter(ArchivedMessage.room_id == room_pk)
                .filter(ArchivedMessage.legacy_id == legacy_id)
                .first()
            )
636
+
637
+
638
class MultiStore(EngineMixin):
    """Map one legacy message ID to several XMPP message IDs, per user."""

    def get_xmpp_ids(self, user_pk: int, xmpp_id: str) -> list[str]:
        """Return every XMPP ID grouped with *xmpp_id* (empty if unknown)."""
        with self.session() as session:
            multi = session.execute(
                select(XmppIdsMulti)
                .where(XmppIdsMulti.xmpp_id == xmpp_id)
                .where(XmppIdsMulti.user_account_id == user_pk)
            ).scalar()
            if multi is None:
                return []
            return [m.xmpp_id for m in multi.legacy_ids_multi.xmpp_ids]

    def set_xmpp_ids(
        self, user_pk: int, legacy_msg_id: str, xmpp_ids: list[str], fail=False
    ) -> None:
        """Associate *legacy_msg_id* with *xmpp_ids*, replacing any previous set.

        Empty XMPP IDs are skipped.

        :param fail: set on the internal retry after a reset; if a mapping
            still exists at that point, :class:`RuntimeError` is raised
            instead of retrying forever.
        :raises RuntimeError: if the previous mapping could not be removed.
        """
        with self.session() as session:
            existing = session.execute(
                select(LegacyIdsMulti)
                .where(LegacyIdsMulti.user_account_id == user_pk)
                .where(LegacyIdsMulti.legacy_id == legacy_msg_id)
            ).scalar()
            if existing is not None:
                if fail:
                    # A bare ``raise`` outside an except block only yields an
                    # unhelpful "No active exception to reraise".
                    raise RuntimeError(
                        f"Could not reset the multi mapping for {legacy_msg_id}"
                    )
                log.warning("Resetting multi for %s", legacy_msg_id)
                session.execute(
                    delete(LegacyIdsMulti)
                    .where(LegacyIdsMulti.user_account_id == user_pk)
                    .where(LegacyIdsMulti.legacy_id == legacy_msg_id)
                )
                for i in xmpp_ids:
                    session.execute(
                        delete(XmppIdsMulti)
                        .where(XmppIdsMulti.user_account_id == user_pk)
                        .where(XmppIdsMulti.xmpp_id == i)
                    )
                session.commit()
                self.set_xmpp_ids(user_pk, legacy_msg_id, xmpp_ids, True)
                return

            row = LegacyIdsMulti(
                user_account_id=user_pk,
                legacy_id=legacy_msg_id,
                xmpp_ids=[
                    XmppIdsMulti(user_account_id=user_pk, xmpp_id=i)
                    for i in xmpp_ids
                    if i
                ],
            )
            session.add(row)
            session.commit()

    def get_legacy_id(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        """Return the legacy ID grouping *xmpp_id*, or ``None``."""
        with self.session() as session:
            multi = session.execute(
                select(XmppIdsMulti)
                .where(XmppIdsMulti.xmpp_id == xmpp_id)
                .where(XmppIdsMulti.user_account_id == user_pk)
            ).scalar()
            if multi is None:
                return None
            return multi.legacy_ids_multi.legacy_id
700
+
701
+
702
class AttachmentStore(EngineMixin):
    """Track uploaded legacy files: their URL and cached ``sims``/``sfs`` strings."""

    def get_url(self, legacy_file_id: str) -> Optional[str]:
        """Return the URL stored for *legacy_file_id* (any user), or ``None``."""
        query = select(Attachment.url).where(
            Attachment.legacy_file_id == legacy_file_id
        )
        with self.session() as session:
            return session.execute(query).scalar()

    def set_url(self, user_pk: int, legacy_file_id: str, url: str) -> None:
        """Create or update the URL of *legacy_file_id* for user *user_pk*."""
        with self.session() as session:
            att = session.execute(
                select(Attachment)
                .where(Attachment.legacy_file_id == legacy_file_id)
                .where(Attachment.user_account_id == user_pk)
            ).scalar()
            if att is not None:
                att.url = url
            else:
                session.add(
                    Attachment(
                        legacy_file_id=legacy_file_id,
                        url=url,
                        user_account_id=user_pk,
                    )
                )
            session.commit()

    def get_sims(self, url: str) -> Optional[str]:
        """Return the ``sims`` value stored for *url*, or ``None``."""
        query = select(Attachment.sims).where(Attachment.url == url)
        with self.session() as session:
            return session.execute(query).scalar()

    def set_sims(self, url: str, sims: str) -> None:
        """Store the ``sims`` value of every attachment with this *url*."""
        statement = update(Attachment).where(Attachment.url == url).values(sims=sims)
        with self.session() as session:
            session.execute(statement)
            session.commit()

    def get_sfs(self, url: str) -> Optional[str]:
        """Return the ``sfs`` value stored for *url*, or ``None``."""
        query = select(Attachment.sfs).where(Attachment.url == url)
        with self.session() as session:
            return session.execute(query).scalar()

    def set_sfs(self, url: str, sfs: str) -> None:
        """Store the ``sfs`` value of every attachment with this *url*."""
        statement = update(Attachment).where(Attachment.url == url).values(sfs=sfs)
        with self.session() as session:
            session.execute(statement)
            session.commit()

    def remove(self, legacy_file_id: str) -> None:
        """Delete every attachment row for *legacy_file_id*."""
        statement = delete(Attachment).where(
            Attachment.legacy_file_id == legacy_file_id
        )
        with self.session() as session:
            session.execute(statement)
            session.commit()
759
+
760
+
761
+ class RoomStore(UpdatedMixin):
762
+ model = Room
763
+
764
+ def __init__(self, *a, **kw):
765
+ super().__init__(*a, **kw)
766
+ with self.session() as session:
767
+ session.execute(
768
+ update(Room).values(
769
+ subject_setter=None, user_resources=None, history_filled=False
770
+ )
771
+ )
772
+ session.commit()
773
+
774
+ def set_avatar(self, room_pk: int, avatar_pk: int | None) -> None:
775
+ with self.session() as session:
776
+ session.execute(
777
+ update(Room).where(Room.id == room_pk).values(avatar_id=avatar_pk)
778
+ )
779
+ session.commit()
780
+
781
+ def get_avatar_legacy_id(self, room_pk: int) -> Optional[str]:
782
+ with self.session() as session:
783
+ room = session.execute(select(Room).where(Room.id == room_pk)).scalar()
784
+ if room is None or room.avatar is None:
785
+ return None
786
+ return room.avatar.legacy_id
787
+
788
+ def get_by_jid(self, user_pk: int, jid: JID) -> Optional[Room]:
789
+ if jid.resource:
790
+ raise TypeError
791
+ with self.session() as session:
792
+ return session.execute(
793
+ select(Room)
794
+ .where(Room.user_account_id == user_pk)
795
+ .where(Room.jid == jid)
796
+ ).scalar()
797
+
798
+ def get_by_legacy_id(self, user_pk: int, legacy_id: str) -> Optional[Room]:
799
+ with self.session() as session:
800
+ return session.execute(
801
+ select(Room)
802
+ .where(Room.user_account_id == user_pk)
803
+ .where(Room.legacy_id == legacy_id)
804
+ ).scalar()
805
+
806
+ def update_subject_setter(self, room_pk: int, subject_setter: str | None):
807
+ with self.session() as session:
808
+ session.execute(
809
+ update(Room)
810
+ .where(Room.id == room_pk)
811
+ .values(subject_setter=subject_setter)
812
+ )
813
+ session.commit()
814
+
815
+ def update(self, muc: "LegacyMUC") -> int:
816
+ with self.session() as session:
817
+ if muc.pk is None:
818
+ row = Room(
819
+ jid=muc.jid,
820
+ legacy_id=str(muc.legacy_id),
821
+ user_account_id=muc.user_pk,
822
+ )
823
+ else:
824
+ row = session.query(Room).filter(Room.id == muc.pk).one()
825
+
826
+ row.updated = True
827
+ row.extra_attributes = muc.serialize_extra_attributes()
828
+ row.name = muc.name
829
+ row.description = muc.description
830
+ row.user_resources = (
831
+ None
832
+ if not muc._user_resources
833
+ else json.dumps(list(muc._user_resources))
834
+ )
835
+ row.muc_type = muc.type
836
+ row.subject = muc.subject
837
+ row.subject_date = muc.subject_date
838
+ row.subject_setter = muc.subject_setter
839
+ row.participants_filled = muc._participants_filled
840
+ row.n_participants = muc._n_participants
841
+ row.user_nick = muc.user_nick
842
+ session.add(row)
843
+ session.commit()
844
+ return row.id
845
+
846
+ def update_subject_date(
847
+ self, room_pk: int, subject_date: Optional[datetime]
848
+ ) -> None:
849
+ with self.session() as session:
850
+ session.execute(
851
+ update(Room).where(Room.id == room_pk).values(subject_date=subject_date)
852
+ )
853
+ session.commit()
854
+
855
+ def update_subject(self, room_pk: int, subject: Optional[str]) -> None:
856
+ with self.session() as session:
857
+ session.execute(
858
+ update(Room).where(Room.id == room_pk).values(subject=subject)
859
+ )
860
+ session.commit()
861
+
862
def update_description(self, room_pk: int, desc: Optional[str]) -> None:
    """Store the room description; ``None`` writes NULL."""
    stmt = update(Room).where(Room.id == room_pk).values(description=desc)
    with self.session() as session:
        session.execute(stmt)
        session.commit()
868
+
869
def update_name(self, room_pk: int, name: Optional[str]) -> None:
    """Store the room name; ``None`` writes NULL."""
    stmt = update(Room).where(Room.id == room_pk).values(name=name)
    with self.session() as session:
        session.execute(stmt)
        session.commit()
873
+
874
def update_n_participants(self, room_pk: int, n: Optional[int]) -> None:
    """Store the participant count of the room; ``None`` writes NULL."""
    stmt = update(Room).where(Room.id == room_pk).values(n_participants=n)
    with self.session() as session:
        session.execute(stmt)
        session.commit()
880
+
881
def delete(self, room_pk: int) -> None:
    """Delete the room with primary key *room_pk* together with its participants.

    Participant rows reference the room through ``Participant.room_id``, so
    they are removed first. Deleting the parent row first would violate the
    foreign key if the database enforces it without ON DELETE CASCADE. Both
    deletions happen in a single committed transaction.
    """
    with self.session() as session:
        session.execute(delete(Participant).where(Participant.room_id == room_pk))
        session.execute(delete(Room).where(Room.id == room_pk))
        session.commit()
886
+
887
def set_resource(self, room_pk: int, resources: set[str]) -> None:
    """Store the user's resources in the room as a JSON list.

    An empty set is stored as NULL rather than as ``"[]"``.
    """
    serialized = json.dumps(list(resources)) if resources else None
    stmt = update(Room).where(Room.id == room_pk).values(user_resources=serialized)
    with self.session() as session:
        session.execute(stmt)
        session.commit()
899
+
900
def nickname_is_available(self, room_pk: int, nickname: str) -> bool:
    """Return ``True`` if no participant of the room uses *nickname* yet."""
    query = (
        select(Participant)
        .where(Participant.room_id == room_pk)
        .where(Participant.nickname == nickname)
    )
    with self.session() as session:
        holder = session.execute(query).scalar()
    return holder is None
910
+
911
def set_participants_filled(self, room_pk: int, val=True) -> None:
    """Record whether the participant list of the room has been fetched."""
    stmt = update(Room).where(Room.id == room_pk).values(participants_filled=val)
    with self.session() as session:
        session.execute(stmt)
        session.commit()
917
+
918
def set_history_filled(self, room_pk: int, val=True) -> None:
    """Record whether the history of the room has been fetched.

    :param room_pk: primary key of the room.
    :param val: value to store in ``history_filled``. It defaults to ``True``;
        passing ``False`` resets the flag. Previously this argument was
        silently ignored and ``True`` was always written.
    """
    with self.session() as session:
        session.execute(
            update(Room).where(Room.id == room_pk).values(history_filled=val)
        )
        session.commit()
924
+
925
def get_all(self, user_pk: int) -> Iterator[Room]:
    """Yield every room belonging to the user account *user_pk*.

    The database session stays open while the generator is being consumed.
    """
    query = select(Room).where(Room.user_account_id == user_pk)
    with self.session() as session:
        for room in session.execute(query).scalars():
            yield room
930
+
931
+
932
class ParticipantStore(EngineMixin):
    """Database access helpers for MUC participants and their hats."""

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)
        # Every Participant and Hat row is wiped when the store is created.
        with self.session() as session:
            session.execute(delete(Participant))
            session.execute(delete(Hat))
            session.commit()

    def add(self, room_pk: int, nickname: str) -> int:
        """Return the pk of the participant *nickname* in the room, creating it if needed."""
        with self.session() as session:
            existing = session.execute(
                select(Participant.id)
                .where(Participant.room_id == room_pk)
                .where(Participant.nickname == nickname)
            ).scalar()
            if existing is not None:
                return existing
            participant = Participant(room_id=room_pk, nickname=nickname)
            session.add(participant)
            session.commit()
            return participant.id

    def get_by_nickname(self, room_pk: int, nickname: str) -> Optional[Participant]:
        """Return the room's participant with this nickname, or ``None``."""
        with self.session() as session:
            return session.execute(
                select(Participant)
                .where(Participant.room_id == room_pk)
                .where(Participant.nickname == nickname)
            ).scalar()

    def get_by_resource(self, room_pk: int, resource: str) -> Optional[Participant]:
        """Return the room's participant with this resource, or ``None``."""
        with self.session() as session:
            return session.execute(
                select(Participant)
                .where(Participant.room_id == room_pk)
                .where(Participant.resource == resource)
            ).scalar()

    def get_by_contact(self, room_pk: int, contact_pk: int) -> Optional[Participant]:
        """Return the room's participant linked to contact *contact_pk*, or ``None``."""
        with self.session() as session:
            return session.execute(
                select(Participant)
                .where(Participant.room_id == room_pk)
                .where(Participant.contact_id == contact_pk)
            ).scalar()

    def get_all(self, room_pk: int, user_included=True) -> Iterator[Participant]:
        """Yield the room's participants, optionally excluding the user's own.

        The session stays open while the generator is being consumed.
        """
        with self.session() as session:
            q = select(Participant).where(Participant.room_id == room_pk)
            if not user_included:
                q = q.where(~Participant.is_user)
            yield from session.execute(q).scalars()

    def get_for_contact(self, contact_pk: int) -> Iterator[Participant]:
        """Yield every participant row (in any room) linked to *contact_pk*."""
        with self.session() as session:
            yield from session.execute(
                select(Participant).where(Participant.contact_id == contact_pk)
            ).scalars()

    def update(self, participant: "LegacyParticipant") -> None:
        """Copy the mutable attributes of *participant* into its row and commit."""
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant.pk)
                .values(
                    resource=participant.jid.resource,
                    affiliation=participant.affiliation,
                    role=participant.role,
                    presence_sent=participant._presence_sent,  # type:ignore
                    # Hats are not written here; set_hats() persists them.
                    is_user=participant.is_user,
                    contact_id=(
                        None
                        if participant.contact is None
                        else participant.contact.contact_pk
                    ),
                )
            )
            session.commit()

    @staticmethod
    def _find_hat(session, uri: str, title: str) -> Optional[Hat]:
        """Return the Hat row with this uri/title visible in *session*, or ``None``."""
        return session.execute(
            select(Hat).where(Hat.uri == uri).where(Hat.title == title)
        ).scalar()

    def add_hat(self, uri: str, title: str) -> Hat:
        """Return the Hat with this *uri* and *title*, creating and committing it if absent."""
        with self.session() as session:
            existing = self._find_hat(session, uri, title)
            if existing is not None:
                return existing
            hat = Hat(uri=uri, title=title)
            session.add(hat)
            session.commit()
            return hat

    def set_presence_sent(self, participant_pk: int) -> None:
        """Mark the participant's presence as sent."""
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant_pk)
                .values(presence_sent=True)
            )
            session.commit()

    def set_affiliation(self, participant_pk: int, affiliation: MucAffiliation) -> None:
        """Store the participant's MUC affiliation."""
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant_pk)
                .values(affiliation=affiliation)
            )
            session.commit()

    def set_role(self, participant_pk: int, role: MucRole) -> None:
        """Store the participant's MUC role."""
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant_pk)
                .values(role=role)
            )
            session.commit()

    def set_hats(self, participant_pk: int, hats: list[HatTuple]) -> None:
        """Replace the participant's hats with *hats* (duplicates are ignored).

        Hats are looked up or created in the same session that holds the
        participant. Using ``add_hat()`` here would return objects from
        another, already closed session, and attaching such detached objects
        can conflict with same-identity Hat rows already loaded in this one.

        :raises ValueError: if no participant has primary key *participant_pk*.
        """
        with self.session() as session:
            part = session.execute(
                select(Participant).where(Participant.id == participant_pk)
            ).scalar()
            if part is None:
                raise ValueError
            part.hats.clear()
            for h in hats:
                uri, title = h
                hat = self._find_hat(session, uri, title)
                if hat is None:
                    hat = Hat(uri=uri, title=title)
                    session.add(hat)
                    # Flush so a later duplicate in *hats* finds this row.
                    session.flush()
                if hat in part.hats:
                    continue
                part.hats.append(hat)
            session.commit()

    def delete(self, participant_pk: int) -> None:
        """Delete the participant row with primary key *participant_pk*."""
        with self.session() as session:
            session.execute(delete(Participant).where(Participant.id == participant_pk))
            # Without an explicit commit the deletion is discarded when the
            # session closes.
            session.commit()

    def get_count(self, room_pk: int) -> int:
        """Return the number of participant rows in the room."""
        with self.session() as session:
            return session.execute(
                select(count(Participant.id)).where(Participant.room_id == room_pk)
            ).scalar()
1075
+
1076
+
1077
log = logging.getLogger(name=__name__)
# Module-level session holder, initially unset (presumably assigned later in
# this module; confirm against its users).
_session: Optional[Session] = None