slidge 0.1.3__py3-none-any.whl → 0.2.0a1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (74):
  1. slidge/__init__.py +3 -5
  2. slidge/__main__.py +2 -196
  3. slidge/__version__.py +5 -0
  4. slidge/command/adhoc.py +8 -1
  5. slidge/command/admin.py +6 -7
  6. slidge/command/base.py +1 -2
  7. slidge/command/register.py +32 -16
  8. slidge/command/user.py +85 -6
  9. slidge/contact/contact.py +165 -49
  10. slidge/contact/roster.py +122 -47
  11. slidge/core/config.py +14 -11
  12. slidge/core/gateway/base.py +148 -36
  13. slidge/core/gateway/caps.py +7 -5
  14. slidge/core/gateway/disco.py +2 -4
  15. slidge/core/gateway/mam.py +1 -4
  16. slidge/core/gateway/muc_admin.py +1 -1
  17. slidge/core/gateway/ping.py +2 -3
  18. slidge/core/gateway/presence.py +1 -1
  19. slidge/core/gateway/registration.py +32 -21
  20. slidge/core/gateway/search.py +3 -5
  21. slidge/core/gateway/session_dispatcher.py +120 -57
  22. slidge/core/gateway/vcard_temp.py +7 -5
  23. slidge/core/mixins/__init__.py +11 -1
  24. slidge/core/mixins/attachment.py +32 -14
  25. slidge/core/mixins/avatar.py +90 -25
  26. slidge/core/mixins/base.py +8 -2
  27. slidge/core/mixins/db.py +18 -0
  28. slidge/core/mixins/disco.py +0 -10
  29. slidge/core/mixins/message.py +18 -8
  30. slidge/core/mixins/message_maker.py +17 -9
  31. slidge/core/mixins/presence.py +17 -4
  32. slidge/core/pubsub.py +54 -220
  33. slidge/core/session.py +69 -34
  34. slidge/db/__init__.py +4 -0
  35. slidge/db/alembic/env.py +64 -0
  36. slidge/db/alembic/script.py.mako +26 -0
  37. slidge/db/alembic/versions/09f27f098baa_add_missing_attributes_in_room.py +36 -0
  38. slidge/db/alembic/versions/2461390c0af2_store_contacts_caps_verstring_in_db.py +36 -0
  39. slidge/db/alembic/versions/29f5280c61aa_store_subject_setter_in_room.py +37 -0
  40. slidge/db/alembic/versions/2b1f45ab7379_store_room_subject_setter_by_nickname.py +41 -0
  41. slidge/db/alembic/versions/82a4af84b679_add_muc_history_filled.py +48 -0
  42. slidge/db/alembic/versions/8d2ced764698_rely_on_db_to_store_contacts_rooms_and_.py +133 -0
  43. slidge/db/alembic/versions/aa9d82a7f6ef_db_creation.py +85 -0
  44. slidge/db/alembic/versions/b33993e87db3_move_everything_to_persistent_db.py +214 -0
  45. slidge/db/alembic/versions/b64b1a793483_add_source_and_legacy_id_for_archived_.py +48 -0
  46. slidge/db/alembic/versions/c4a8ec35a0e8_per_room_user_nick.py +34 -0
  47. slidge/db/alembic/versions/e91195719c2c_store_users_avatars_persistently.py +26 -0
  48. slidge/db/avatar.py +235 -0
  49. slidge/db/meta.py +65 -0
  50. slidge/db/models.py +375 -0
  51. slidge/db/store.py +1078 -0
  52. slidge/group/archive.py +58 -14
  53. slidge/group/bookmarks.py +72 -57
  54. slidge/group/participant.py +87 -28
  55. slidge/group/room.py +369 -211
  56. slidge/main.py +201 -0
  57. slidge/migration.py +30 -0
  58. slidge/slixfix/__init__.py +35 -2
  59. slidge/slixfix/roster.py +11 -4
  60. slidge/slixfix/xep_0292/vcard4.py +3 -0
  61. slidge/util/archive_msg.py +2 -1
  62. slidge/util/db.py +1 -47
  63. slidge/util/test.py +71 -4
  64. slidge/util/types.py +29 -4
  65. slidge/util/util.py +22 -0
  66. {slidge-0.1.3.dist-info → slidge-0.2.0a1.dist-info}/METADATA +4 -4
  67. slidge-0.2.0a1.dist-info/RECORD +114 -0
  68. slidge/core/cache.py +0 -183
  69. slidge/util/schema.sql +0 -126
  70. slidge/util/sql.py +0 -508
  71. slidge-0.1.3.dist-info/RECORD +0 -96
  72. {slidge-0.1.3.dist-info → slidge-0.2.0a1.dist-info}/LICENSE +0 -0
  73. {slidge-0.1.3.dist-info → slidge-0.2.0a1.dist-info}/WHEEL +0 -0
  74. {slidge-0.1.3.dist-info → slidge-0.2.0a1.dist-info}/entry_points.txt +0 -0
slidge/db/store.py ADDED
@@ -0,0 +1,1078 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import logging
5
+ from contextlib import contextmanager
6
+ from datetime import datetime, timedelta, timezone
7
+ from typing import TYPE_CHECKING, Collection, Iterator, Optional, Type
8
+
9
+ from slixmpp import JID, Iq, Message, Presence
10
+ from slixmpp.exceptions import XMPPError
11
+ from sqlalchemy import Engine, delete, select, update
12
+ from sqlalchemy.orm import Session, attributes
13
+ from sqlalchemy.sql.functions import count
14
+
15
+ from ..util.archive_msg import HistoryMessage
16
+ from ..util.types import URL, CachedPresence
17
+ from ..util.types import Hat as HatTuple
18
+ from ..util.types import MamMetadata, MucAffiliation, MucRole
19
+ from .meta import Base
20
+ from .models import (
21
+ ArchivedMessage,
22
+ ArchivedMessageSource,
23
+ Attachment,
24
+ Avatar,
25
+ Contact,
26
+ ContactSent,
27
+ GatewayUser,
28
+ Hat,
29
+ LegacyIdsMulti,
30
+ Participant,
31
+ Room,
32
+ XmppIdsMulti,
33
+ XmppToLegacyEnum,
34
+ XmppToLegacyIds,
35
+ )
36
+
37
+ if TYPE_CHECKING:
38
+ from ..contact.contact import LegacyContact
39
+ from ..group.participant import LegacyParticipant
40
+ from ..group.room import LegacyMUC
41
+
42
+
43
class EngineMixin:
    """Base class of every store: holds the SQLAlchemy engine.

    All stores share a single module-level "current session" (``_session``):
    a ``with self.session()`` block opened while another one is active, even
    one belonging to a different store, reuses the outer session instead of
    opening a new one.
    """

    def __init__(self, engine: Engine):
        self._engine = engine

    @contextmanager
    def session(self, **session_kwargs) -> Iterator[Session]:
        """Yield the shared session, opening a new one if none is active.

        When a session is already active it is yielded unchanged and
        ``session_kwargs`` are ignored (NOTE(review): e.g.
        ``expire_on_commit=False`` has no effect in a nested call).
        Otherwise a new :class:`Session` is built from ``session_kwargs``,
        registered as the shared session for the duration of the block and
        closed when the block exits.
        """
        global _session
        if _session is None:
            with Session(self._engine, **session_kwargs) as fresh_session:
                _session = fresh_session
                try:
                    yield fresh_session
                finally:
                    # Unregister before the ``with`` statement closes it.
                    _session = None
        else:
            yield _session
59
+
60
+
61
class UpdatedMixin(EngineMixin):
    """Base store for tables whose rows carry an ``updated`` flag.

    On construction, every row whose ``updated`` flag is false is deleted,
    then the flag is cleared on all remaining rows.
    """

    # Subclasses set this to the ORM model they manage.
    model: Type[Base] = NotImplemented

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)
        model = self.model
        with self.session() as session:
            prune_stale = delete(model).where(~model.updated)  # type:ignore
            session.execute(prune_stale)
            clear_flags = update(model).values(updated=False)  # type:ignore
            session.execute(clear_flags)
            session.commit()

    def get_by_pk(self, pk: int) -> Optional[Base]:
        """Return the ``model`` row whose primary key is ``pk``, or None."""
        query = select(self.model).where(self.model.id == pk)  # type:ignore
        with self.session() as session:
            return session.execute(query).scalar()
78
+
79
+
80
class SlidgeStore(EngineMixin):
    """Facade grouping every table-specific store, all bound to one engine."""

    def __init__(self, engine: Engine):
        super().__init__(engine)
        # Construction order is kept as-is: several of these stores run
        # UPDATE/DELETE statements in their constructors.
        self.users = UserStore(engine)
        self.avatars = AvatarStore(engine)
        self.contacts = ContactStore(engine)
        self.mam = MAMStore(engine)
        self.multi = MultiStore(engine)
        self.attachments = AttachmentStore(engine)
        self.rooms = RoomStore(engine)
        self.sent = SentStore(engine)
        self.participants = ParticipantStore(engine)
92
+
93
+
94
class UserStore(EngineMixin):
    """Persistence for gateway users (``GatewayUser`` rows)."""

    def new(self, jid: JID, legacy_module_data: dict) -> GatewayUser:
        """Return the user for ``jid``, creating it if it does not exist.

        The resource part of ``jid`` is discarded. ``legacy_module_data`` is
        only stored when a new user is created; an existing user is returned
        unchanged.
        """
        if jid.resource:
            jid = JID(jid.bare)
        # expire_on_commit=False keeps the returned object's attributes
        # loaded after the commit. NOTE(review): this kwarg is ignored when a
        # shared session is already open.
        with self.session(expire_on_commit=False) as session:
            user = session.execute(
                select(GatewayUser).where(GatewayUser.jid == jid)
            ).scalar()
            if user is not None:
                return user
            user = GatewayUser(jid=jid, legacy_module_data=legacy_module_data)
            session.add(user)
            session.commit()
            return user

    def update(self, user: GatewayUser):
        """Persist ``user``, including in-place changes to its
        ``legacy_module_data`` and ``preferences`` attributes."""
        # In-place mutations of these attributes are presumably not tracked
        # by SQLAlchemy (see link), so they are flagged as modified explicitly.
        # https://github.com/sqlalchemy/sqlalchemy/discussions/6473
        attributes.flag_modified(user, "legacy_module_data")
        attributes.flag_modified(user, "preferences")
        with self.session() as session:
            session.add(user)
            session.commit()

    def get_all(self) -> Iterator[GatewayUser]:
        """Yield every gateway user; the session stays open while iterating."""
        with self.session() as session:
            yield from session.execute(select(GatewayUser)).scalars()

    def get(self, jid: JID) -> Optional[GatewayUser]:
        """Return the user whose JID equals the bare part of ``jid``, or None."""
        with self.session() as session:
            return session.execute(
                select(GatewayUser).where(GatewayUser.jid == jid.bare)
            ).scalar()

    def get_by_stanza(self, stanza: Iq | Message | Presence) -> Optional[GatewayUser]:
        """Return the user who sent ``stanza``, or None."""
        return self.get(stanza.get_from())

    def delete(self, jid: JID) -> None:
        """Delete the user matching ``jid``; do nothing if there is none."""
        with self.session() as session:
            # self.get() reuses this same (shared) session, so the returned
            # instance is attached to it and can be deleted directly.
            user = self.get(jid)
            if user is None:
                # session.delete(None) would raise an unmapped-instance error.
                return
            session.delete(user)
            session.commit()
134
+
135
+
136
class AvatarStore(EngineMixin):
    """Lookup and deletion of ``Avatar`` rows."""

    def _first_matching(self, condition) -> Optional[Avatar]:
        # First Avatar row satisfying ``condition`` (a SQL expression), or None.
        with self.session() as session:
            return session.execute(select(Avatar).where(condition)).scalar()

    def get_by_url(self, url: URL | str) -> Optional[Avatar]:
        """Return the avatar whose ``url`` column equals ``url``, or None."""
        return self._first_matching(Avatar.url == url)

    def get_by_legacy_id(self, legacy_id: str) -> Optional[Avatar]:
        """Return the avatar whose ``legacy_id`` equals ``legacy_id``, or None."""
        return self._first_matching(Avatar.legacy_id == legacy_id)

    def get_by_pk(self, pk: int) -> Optional[Avatar]:
        """Return the avatar with primary key ``pk``, or None."""
        return self._first_matching(Avatar.id == pk)

    def delete_by_pk(self, pk: int):
        """Delete the avatar with primary key ``pk`` (no-op if absent)."""
        statement = delete(Avatar).where(Avatar.id == pk)
        with self.session() as session:
            session.execute(statement)
            session.commit()

    def get_all(self) -> Iterator[Avatar]:
        """Yield every avatar; the session stays open while iterating."""
        with self.session() as session:
            yield from session.execute(select(Avatar)).scalars()
155
+
156
+ def get_all(self) -> Iterator[Avatar]:
157
+ with self.session() as session:
158
+ yield from session.execute(select(Avatar)).scalars()
159
+
160
+
161
class SentStore(EngineMixin):
    """Per-user mapping between legacy message IDs and XMPP message IDs.

    Each mapping row is tagged with a ``XmppToLegacyEnum`` kind (DM,
    GROUP_CHAT or THREAD). Lookups only match rows of the requested kind,
    except :meth:`was_sent_by_user`, which ignores the kind.
    """

    def _store(
        self, user_pk: int, legacy_id: str, xmpp_id: str, kind: XmppToLegacyEnum
    ) -> None:
        # Insert one mapping row of the given kind and commit it.
        mapping = XmppToLegacyIds(
            user_account_id=user_pk,
            legacy_id=legacy_id,
            xmpp_id=xmpp_id,
            type=kind,
        )
        with self.session() as session:
            session.add(mapping)
            session.commit()

    def _xmpp_id_for(
        self, user_pk: int, legacy_id: str, kind: XmppToLegacyEnum
    ) -> Optional[str]:
        # XMPP ID mapped to ``legacy_id`` for this user and kind, if any.
        query = (
            select(XmppToLegacyIds.xmpp_id)
            .where(XmppToLegacyIds.user_account_id == user_pk)
            .where(XmppToLegacyIds.legacy_id == legacy_id)
            .where(XmppToLegacyIds.type == kind)
        )
        with self.session() as session:
            return session.execute(query).scalar()

    def _legacy_id_for(
        self, user_pk: int, xmpp_id: str, kind: XmppToLegacyEnum
    ) -> Optional[str]:
        # Legacy ID mapped to ``xmpp_id`` for this user and kind, if any.
        query = (
            select(XmppToLegacyIds.legacy_id)
            .where(XmppToLegacyIds.user_account_id == user_pk)
            .where(XmppToLegacyIds.xmpp_id == xmpp_id)
            .where(XmppToLegacyIds.type == kind)
        )
        with self.session() as session:
            return session.execute(query).scalar()

    def set_message(self, user_pk: int, legacy_id: str, xmpp_id: str) -> None:
        """Record a direct-message ID mapping."""
        self._store(user_pk, legacy_id, xmpp_id, XmppToLegacyEnum.DM)

    def get_xmpp_id(self, user_pk: int, legacy_id: str) -> Optional[str]:
        """Return the XMPP ID of a direct message from its legacy ID."""
        return self._xmpp_id_for(user_pk, legacy_id, XmppToLegacyEnum.DM)

    def get_legacy_id(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        """Return the legacy ID of a direct message from its XMPP ID."""
        return self._legacy_id_for(user_pk, xmpp_id, XmppToLegacyEnum.DM)

    def set_group_message(self, user_pk: int, legacy_id: str, xmpp_id: str) -> None:
        """Record a group-chat message ID mapping."""
        self._store(user_pk, legacy_id, xmpp_id, XmppToLegacyEnum.GROUP_CHAT)

    def get_group_xmpp_id(self, user_pk: int, legacy_id: str) -> Optional[str]:
        """Return the XMPP ID of a group-chat message from its legacy ID."""
        return self._xmpp_id_for(user_pk, legacy_id, XmppToLegacyEnum.GROUP_CHAT)

    def get_group_legacy_id(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        """Return the legacy ID of a group-chat message from its XMPP ID."""
        return self._legacy_id_for(user_pk, xmpp_id, XmppToLegacyEnum.GROUP_CHAT)

    def set_thread(self, user_pk: int, legacy_id: str, xmpp_id: str) -> None:
        """Record a thread ID mapping."""
        self._store(user_pk, legacy_id, xmpp_id, XmppToLegacyEnum.THREAD)

    def get_legacy_thread(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        """Return the legacy thread ID mapped to an XMPP thread ID."""
        return self._legacy_id_for(user_pk, xmpp_id, XmppToLegacyEnum.THREAD)

    def get_xmpp_thread(self, user_pk: int, legacy_id: str) -> Optional[str]:
        """Return the XMPP thread ID mapped to a legacy thread ID."""
        return self._xmpp_id_for(user_pk, legacy_id, XmppToLegacyEnum.THREAD)

    def was_sent_by_user(self, user_pk: int, legacy_id: str) -> bool:
        """Return True if any mapping, of any kind, exists for ``legacy_id``."""
        query = (
            select(XmppToLegacyIds.legacy_id)
            .where(XmppToLegacyIds.user_account_id == user_pk)
            .where(XmppToLegacyIds.legacy_id == legacy_id)
        )
        with self.session() as session:
            found = session.execute(query).scalar()
            return found is not None
259
+
260
+
261
class ContactStore(UpdatedMixin):
    """Persistence for the legacy contacts (``Contact`` rows) of gateway users.

    Besides contact attributes, this stores a cached presence per contact and
    a per-contact list of message IDs (``ContactSent`` rows).
    """

    model = Contact

    def __init__(self, *a, **k):
        super().__init__(*a, **k)
        with self.session() as session:
            # Invalidate every cached presence at startup: get_presence()
            # returns None until set_presence() is called again.
            session.execute(update(Contact).values(cached_presence=False))
            session.commit()

    def get_all(self, user_pk: int) -> Iterator[Contact]:
        """Yield every contact of the user with primary key ``user_pk``."""
        with self.session() as session:
            yield from session.execute(
                select(Contact).where(Contact.user_account_id == user_pk)
            ).scalars()

    def get_by_jid(self, user_pk: int, jid: JID) -> Optional[Contact]:
        """Return the user's contact whose JID is the bare part of ``jid``."""
        with self.session() as session:
            return session.execute(
                select(Contact)
                .where(Contact.jid == jid.bare)
                .where(Contact.user_account_id == user_pk)
            ).scalar()

    def get_by_legacy_id(self, user_pk: int, legacy_id: str) -> Optional[Contact]:
        """Return the user's contact with the given legacy ID, or None."""
        with self.session() as session:
            return session.execute(
                select(Contact)
                .where(Contact.legacy_id == legacy_id)
                .where(Contact.user_account_id == user_pk)
            ).scalar()

    def update_nick(self, contact_pk: int, nick: Optional[str]) -> None:
        """Set the nickname of a contact."""
        with self.session() as session:
            session.execute(
                update(Contact).where(Contact.id == contact_pk).values(nick=nick)
            )
            session.commit()

    def get_presence(self, contact_pk: int) -> Optional[CachedPresence]:
        """Return the cached presence of a contact, or None if none is cached."""
        with self.session() as session:
            presence = session.execute(
                select(
                    Contact.last_seen,
                    Contact.ptype,
                    Contact.pstatus,
                    Contact.pshow,
                    Contact.cached_presence,
                ).where(Contact.id == contact_pk)
            ).first()
            # The last column is the "is cached" flag; the other columns map
            # positionally onto the CachedPresence fields.
            if presence is None or not presence[-1]:
                return None
            return CachedPresence(*presence[:-1])

    def set_presence(self, contact_pk: int, presence: CachedPresence) -> None:
        """Cache ``presence`` for a contact."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(**presence._asdict(), cached_presence=True)
            )
            session.commit()

    def reset_presence(self, contact_pk: int) -> None:
        """Clear the cached presence of a contact."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(
                    last_seen=None,
                    ptype=None,
                    pstatus=None,
                    pshow=None,
                    cached_presence=False,
                )
            )
            session.commit()

    def set_avatar(self, contact_pk: int, avatar_pk: Optional[int]):
        """Point a contact at an ``Avatar`` row (or at none)."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(avatar_id=avatar_pk)
            )
            session.commit()

    def get_avatar_legacy_id(self, contact_pk: int) -> Optional[str]:
        """Return the legacy ID of a contact's avatar, or None."""
        with self.session() as session:
            contact = session.execute(
                select(Contact).where(Contact.id == contact_pk)
            ).scalar()
            if contact is None or contact.avatar is None:
                return None
            return contact.avatar.legacy_id

    def update(self, contact: "LegacyContact", commit=True) -> int:
        """Insert or update the row for ``contact`` and return its primary key.

        A new row is created when ``contact.contact_pk`` is None (including
        the contact's cached presence, if any); otherwise the existing row is
        loaded and updated. The row is marked ``updated`` so the stale-row
        cleanup done in ``UpdatedMixin.__init__`` keeps it.

        :param commit: commit the session. When False, changes are only
            flushed, so that a newly created row still gets an ``id``.
        """
        with self.session() as session:
            if contact.contact_pk is None:
                if contact.cached_presence is not None:
                    presence_kwargs = contact.cached_presence._asdict()
                    presence_kwargs["cached_presence"] = True
                else:
                    presence_kwargs = {}
                row = Contact(
                    jid=contact.jid.bare,
                    legacy_id=str(contact.legacy_id),
                    user_account_id=contact.user_pk,
                    **presence_kwargs,
                )
            else:
                row = (
                    session.query(Contact)
                    .filter(Contact.id == contact.contact_pk)
                    .one()
                )
            row.nick = contact.name
            row.is_friend = contact.is_friend
            row.added_to_roster = contact.added_to_roster
            row.updated = True
            row.extra_attributes = contact.serialize_extra_attributes()
            row.caps_ver = contact._caps_ver
            session.add(row)
            if commit:
                session.commit()
            else:
                # Without a flush, a brand-new row has no primary key yet and
                # row.id would be None despite the ``-> int`` contract.
                session.flush()
            # Read inside the session: after commit the attribute is expired
            # and must be refreshed from the database.
            return row.id

    def add_to_sent(self, contact_pk: int, msg_id: str) -> None:
        """Append ``msg_id`` to the contact's list of sent message IDs."""
        with self.session() as session:
            new = ContactSent(contact_id=contact_pk, msg_id=msg_id)
            session.add(new)
            session.commit()

    def pop_sent_up_to(self, contact_pk: int, msg_id: str) -> list[str]:
        """Remove and return the contact's stored message IDs, oldest first,
        up to and including ``msg_id``.

        If ``msg_id`` is not among them, every stored ID of the contact is
        removed and returned.
        """
        result = []
        to_del = []
        with self.session() as session:
            for row in session.execute(
                select(ContactSent)
                .where(ContactSent.contact_id == contact_pk)
                .order_by(ContactSent.id)
            ).scalars():
                to_del.append(row.id)
                result.append(row.msg_id)
                if row.msg_id == msg_id:
                    break
            if to_del:
                session.execute(
                    delete(ContactSent).where(ContactSent.id.in_(to_del))
                )
                # Commit explicitly: otherwise the deletion is discarded when
                # a standalone session is closed, and the same IDs would be
                # "popped" again on the next call.
                session.commit()
        return result

    def set_friend(self, contact_pk: int, is_friend: bool) -> None:
        """Set the ``is_friend`` flag of a contact."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(is_friend=is_friend)
            )
            session.commit()

    def set_added_to_roster(self, contact_pk: int, value: bool) -> None:
        """Set the ``added_to_roster`` flag of a contact."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(added_to_roster=value)
            )
            session.commit()

    def delete(self, contact_pk: int) -> None:
        """Delete the contact with primary key ``contact_pk``."""
        with self.session() as session:
            session.execute(delete(Contact).where(Contact.id == contact_pk))
            session.commit()
432
+
433
+
434
class MAMStore(EngineMixin):
    """Message archive of rooms (``ArchivedMessage`` rows)."""

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)
        with self.session() as session:
            # At startup, every archived message is re-labelled as BACKFILL.
            session.execute(
                update(ArchivedMessage).values(source=ArchivedMessageSource.BACKFILL)
            )
            session.commit()

    def nuke_older_than(self, days: int) -> None:
        """Delete archived messages, in every room, older than ``days`` days."""
        # Archived timestamps are interpreted as UTC when read back (see
        # get_messages), so the cutoff is computed in UTC too; a naive
        # datetime.now() is local time and would shift the cutoff by the
        # host's UTC offset.
        cutoff = datetime.now(timezone.utc) - timedelta(days=days)
        with self.session() as session:
            session.execute(
                delete(ArchivedMessage).where(ArchivedMessage.timestamp < cutoff)
            )
            session.commit()

    def add_message(
        self,
        room_pk: int,
        message: HistoryMessage,
        archive_only: bool,
        legacy_msg_id: str | None,
    ) -> None:
        """Archive ``message`` in a room, or overwrite it if already archived.

        A message is considered already archived when a row with the same
        room and stanza ID exists.

        :param archive_only: store it with source BACKFILL instead of LIVE.
        :param legacy_msg_id: legacy ID to store alongside, if any.
        """
        with self.session() as session:
            source = (
                ArchivedMessageSource.BACKFILL
                if archive_only
                else ArchivedMessageSource.LIVE
            )
            existing = session.execute(
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.stanza_id == message.id)
            ).scalar()
            if existing is not None:
                log.debug("Updating message %s in room %s", message.id, room_pk)
                existing.timestamp = message.when
                existing.stanza = str(message.stanza)
                existing.author_jid = message.stanza.get_from()
                existing.source = source
                existing.legacy_id = legacy_msg_id
                session.add(existing)
                session.commit()
                return
            mam_msg = ArchivedMessage(
                stanza_id=message.id,
                timestamp=message.when,
                stanza=str(message.stanza),
                author_jid=message.stanza.get_from(),
                room_id=room_pk,
                source=source,
                legacy_id=legacy_msg_id,
            )
            session.add(mam_msg)
            session.commit()

    @staticmethod
    def _anchor_timestamp(session: Session, room_pk: int, stanza_id: str) -> datetime:
        # Timestamp of a message of *this* room, used as a paging anchor.
        # Stanza IDs are only deduplicated per room (see add_message), so the
        # lookup must be scoped to the room.
        stamp = session.execute(
            select(ArchivedMessage.timestamp)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.stanza_id == stanza_id)
        ).scalar()
        if stamp is None:
            raise XMPPError(
                "item-not-found",
                f"Message {stanza_id} not found",
            )
        return stamp

    def get_messages(
        self,
        room_pk: int,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        before_id: Optional[str] = None,
        after_id: Optional[str] = None,
        ids: Collection[str] = (),
        last_page_n: Optional[int] = None,
        sender: Optional[str] = None,
        flip=False,
    ) -> Iterator[HistoryMessage]:
        """Yield the archived messages of a room that match every filter.

        :param start_date: only messages at or after this time.
        :param end_date: only messages at or before this time.
        :param before_id: only messages strictly older than the message of
            this room with that stanza ID.
        :param after_id: only messages strictly newer than the message of
            this room with that stanza ID.
        :param ids: only messages with these stanza IDs; all must be found.
        :param last_page_n: keep only this many of the most recent matching
            messages (0 or less yields nothing).
        :param sender: only messages whose ``author_jid`` equals this.
        :param flip: yield newest first instead of oldest first.
        :raises XMPPError: ``item-not-found`` when ``before_id``, ``after_id``
            or one of ``ids`` cannot be found. As this is a generator, the
            error is raised when iteration starts.
        """
        with self.session() as session:
            q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)
            if start_date is not None:
                q = q.where(ArchivedMessage.timestamp >= start_date)
            if end_date is not None:
                q = q.where(ArchivedMessage.timestamp <= end_date)
            if before_id is not None:
                stamp = self._anchor_timestamp(session, room_pk, before_id)
                q = q.where(ArchivedMessage.timestamp < stamp)
            if after_id is not None:
                stamp = self._anchor_timestamp(session, room_pk, after_id)
                q = q.where(ArchivedMessage.timestamp > stamp)
            if ids:
                q = q.filter(ArchivedMessage.stanza_id.in_(ids))
            if sender is not None:
                q = q.where(ArchivedMessage.author_jid == sender)
            if flip:
                q = q.order_by(ArchivedMessage.timestamp.desc())
            else:
                q = q.order_by(ArchivedMessage.timestamp.asc())
            msgs = list(session.execute(q).scalars())
            # Compare against the distinct IDs: duplicates in ``ids`` match
            # a single row and must not trigger a spurious error.
            if ids and len(msgs) != len(set(ids)):
                raise XMPPError(
                    "item-not-found",
                    "One of the requested messages IDs could not be found "
                    "with the given constraints.",
                )
            if last_page_n is not None:
                if last_page_n <= 0:
                    # msgs[-0:] would be the whole list, not an empty page.
                    msgs = []
                elif flip:
                    msgs = msgs[:last_page_n]
                else:
                    msgs = msgs[-last_page_n:]
            for h in msgs:
                yield HistoryMessage(
                    stanza=str(h.stanza), when=h.timestamp.replace(tzinfo=timezone.utc)
                )

    def get_first(self, room_pk: int, with_legacy_id=False) -> ArchivedMessage | None:
        """Return the oldest archived message of a room, or None.

        :param with_legacy_id: only consider messages that have a legacy ID.
        """
        with self.session() as session:
            q = (
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .order_by(ArchivedMessage.timestamp.asc())
            )
            if with_legacy_id:
                q = q.filter(ArchivedMessage.legacy_id.isnot(None))
            return session.execute(q).scalar()

    def get_last(
        self, room_pk: int, source: ArchivedMessageSource | None = None
    ) -> ArchivedMessage | None:
        """Return the most recent archived message of a room, or None.

        :param source: only consider messages with this source.
        """
        with self.session() as session:
            q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)

            if source is not None:
                q = q.where(ArchivedMessage.source == source)

            return session.execute(
                q.order_by(ArchivedMessage.timestamp.desc())
            ).scalar()

    def get_first_and_last(self, room_pk: int) -> list[MamMetadata]:
        """Return the metadata of the oldest and newest messages of a room.

        The list is empty for an empty archive.
        """
        r = []
        # One outer session, shared by get_first() and get_last().
        with self.session():
            first = self.get_first(room_pk)
            if first is not None:
                r.append(MamMetadata(first.stanza_id, first.timestamp))
            last = self.get_last(room_pk)
            if last is not None:
                r.append(MamMetadata(last.stanza_id, last.timestamp))
        return r

    def get_most_recent_with_legacy_id(
        self, room_pk: int, source: ArchivedMessageSource | None = None
    ) -> ArchivedMessage | None:
        """Return the newest message of a room that has a legacy ID, or None.

        :param source: only consider messages with this source.
        """
        with self.session() as session:
            q = (
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.legacy_id.isnot(None))
            )
            if source is not None:
                q = q.where(ArchivedMessage.source == source)
            return session.execute(
                q.order_by(ArchivedMessage.timestamp.desc())
            ).scalar()

    def get_least_recent_with_legacy_id_after(
        self, room_pk: int, after_id: str, source=ArchivedMessageSource.LIVE
    ) -> ArchivedMessage | None:
        """Return the oldest message of a room with a legacy ID and the given
        source that is strictly newer than the message with legacy ID
        ``after_id``, or None (also when ``after_id`` is unknown).
        """
        with self.session() as session:
            # Scoped to the room; Result.scalar() returns the first row rather
            # than raising MultipleResultsFound like legacy Query.scalar().
            after_timestamp = session.execute(
                select(ArchivedMessage.timestamp)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.legacy_id == after_id)
            ).scalar()
            if after_timestamp is None:
                return None
            q = (
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.legacy_id.isnot(None))
                .where(ArchivedMessage.source == source)
                .where(ArchivedMessage.timestamp > after_timestamp)
            )
            return session.execute(q.order_by(ArchivedMessage.timestamp.asc())).scalar()

    def get_by_legacy_id(self, room_pk: int, legacy_id: str) -> ArchivedMessage | None:
        """Return an archived message of a room by legacy ID, or None."""
        with self.session() as session:
            return (
                session.query(ArchivedMessage)
                .filter(ArchivedMessage.room_id == room_pk)
                .filter(ArchivedMessage.legacy_id == legacy_id)
                .first()
            )
636
+
637
+
638
class MultiStore(EngineMixin):
    """Per-user mapping of one legacy message ID to several XMPP message IDs.

    Each legacy ID is a ``LegacyIdsMulti`` row whose ``xmpp_ids`` are
    ``XmppIdsMulti`` rows, one per XMPP ID.
    """

    def get_xmpp_ids(self, user_pk: int, xmpp_id: str) -> list[str]:
        """Return every XMPP ID that shares a legacy message with ``xmpp_id``
        (``xmpp_id`` included), or ``[]`` if ``xmpp_id`` is unknown."""
        with self.session() as session:
            multi = session.execute(
                select(XmppIdsMulti)
                .where(XmppIdsMulti.xmpp_id == xmpp_id)
                .where(XmppIdsMulti.user_account_id == user_pk)
            ).scalar()
            if multi is None:
                return []
            return [m.xmpp_id for m in multi.legacy_ids_multi.xmpp_ids]

    def set_xmpp_ids(
        self, user_pk: int, legacy_msg_id: str, xmpp_ids: list[str], fail=False
    ) -> None:
        """Associate ``legacy_msg_id`` with the XMPP IDs in ``xmpp_ids``.

        If an association already exists for ``legacy_msg_id``, a warning is
        logged. The existing association is then deleted, along with any
        previous rows for the given XMPP IDs, before the new one is stored.
        Empty strings in ``xmpp_ids`` are skipped.

        :param fail: raise RuntimeError instead of replacing an existing
            association.
        """
        with self.session() as session:
            existing = session.execute(
                select(LegacyIdsMulti)
                .where(LegacyIdsMulti.user_account_id == user_pk)
                .where(LegacyIdsMulti.legacy_id == legacy_msg_id)
            ).scalar()
            if existing is not None:
                if fail:
                    # A bare ``raise`` here has no active exception to
                    # re-raise; raise a RuntimeError that says what happened.
                    raise RuntimeError(
                        f"XMPP IDs are already set for legacy ID {legacy_msg_id}"
                    )
                log.warning("Resetting multi for %s", legacy_msg_id)
                session.execute(
                    delete(LegacyIdsMulti)
                    .where(LegacyIdsMulti.user_account_id == user_pk)
                    .where(LegacyIdsMulti.legacy_id == legacy_msg_id)
                )
                for i in xmpp_ids:
                    session.execute(
                        delete(XmppIdsMulti)
                        .where(XmppIdsMulti.user_account_id == user_pk)
                        .where(XmppIdsMulti.xmpp_id == i)
                    )
                # Commit the removal before inserting the replacement rows.
                session.commit()

            row = LegacyIdsMulti(
                user_account_id=user_pk,
                legacy_id=legacy_msg_id,
                xmpp_ids=[
                    XmppIdsMulti(user_account_id=user_pk, xmpp_id=i)
                    for i in xmpp_ids
                    if i
                ],
            )
            session.add(row)
            session.commit()

    def get_legacy_id(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        """Return the legacy message ID associated with ``xmpp_id``, or None."""
        with self.session() as session:
            multi = session.execute(
                select(XmppIdsMulti)
                .where(XmppIdsMulti.xmpp_id == xmpp_id)
                .where(XmppIdsMulti.user_account_id == user_pk)
            ).scalar()
            if multi is None:
                return None
            return multi.legacy_ids_multi.legacy_id
700
+
701
+
702
class AttachmentStore(EngineMixin):
    """Attachment cache: legacy file ID to URL, plus ``sims``/``sfs`` values
    stored per URL."""

    def _value_for_url(self, column, url: str) -> Optional[str]:
        # Value of ``column`` on the first Attachment row with this URL.
        with self.session() as session:
            return session.execute(select(column).where(Attachment.url == url)).scalar()

    def _update_for_url(self, url: str, **values) -> None:
        # Set ``values`` on every Attachment row with this URL and commit.
        statement = update(Attachment).where(Attachment.url == url).values(**values)
        with self.session() as session:
            session.execute(statement)
            session.commit()

    def get_url(self, legacy_file_id: str) -> Optional[str]:
        """Return the URL stored for ``legacy_file_id`` (any user), or None."""
        query = select(Attachment.url).where(
            Attachment.legacy_file_id == legacy_file_id
        )
        with self.session() as session:
            return session.execute(query).scalar()

    def set_url(self, user_pk: int, legacy_file_id: str, url: str) -> None:
        """Create or update the user's attachment row for ``legacy_file_id``."""
        with self.session() as session:
            attachment = session.execute(
                select(Attachment)
                .where(Attachment.legacy_file_id == legacy_file_id)
                .where(Attachment.user_account_id == user_pk)
            ).scalar()
            if attachment is not None:
                attachment.url = url
            else:
                session.add(
                    Attachment(
                        legacy_file_id=legacy_file_id,
                        url=url,
                        user_account_id=user_pk,
                    )
                )
            session.commit()

    def get_sims(self, url: str) -> Optional[str]:
        """Return the stored ``sims`` value for ``url``, or None."""
        return self._value_for_url(Attachment.sims, url)

    def set_sims(self, url: str, sims: str) -> None:
        """Store the ``sims`` value for ``url``."""
        self._update_for_url(url, sims=sims)

    def get_sfs(self, url: str) -> Optional[str]:
        """Return the stored ``sfs`` value for ``url``, or None."""
        return self._value_for_url(Attachment.sfs, url)

    def set_sfs(self, url: str, sfs: str) -> None:
        """Store the ``sfs`` value for ``url``."""
        self._update_for_url(url, sfs=sfs)

    def remove(self, legacy_file_id: str) -> None:
        """Delete every attachment row for ``legacy_file_id``."""
        statement = delete(Attachment).where(
            Attachment.legacy_file_id == legacy_file_id
        )
        with self.session() as session:
            session.execute(statement)
            session.commit()
759
+
760
+
761
+ class RoomStore(UpdatedMixin):
762
+ model = Room
763
+
764
+ def __init__(self, *a, **kw):
765
+ super().__init__(*a, **kw)
766
+ with self.session() as session:
767
+ session.execute(
768
+ update(Room).values(
769
+ subject_setter=None, user_resources=None, history_filled=False
770
+ )
771
+ )
772
+ session.commit()
773
+
774
+ def set_avatar(self, room_pk: int, avatar_pk: int | None) -> None:
775
+ with self.session() as session:
776
+ session.execute(
777
+ update(Room).where(Room.id == room_pk).values(avatar_id=avatar_pk)
778
+ )
779
+ session.commit()
780
+
781
+ def get_avatar_legacy_id(self, room_pk: int) -> Optional[str]:
782
+ with self.session() as session:
783
+ room = session.execute(select(Room).where(Room.id == room_pk)).scalar()
784
+ if room is None or room.avatar is None:
785
+ return None
786
+ return room.avatar.legacy_id
787
+
788
+ def get_by_jid(self, user_pk: int, jid: JID) -> Optional[Room]:
789
+ if jid.resource:
790
+ raise TypeError
791
+ with self.session() as session:
792
+ return session.execute(
793
+ select(Room)
794
+ .where(Room.user_account_id == user_pk)
795
+ .where(Room.jid == jid)
796
+ ).scalar()
797
+
798
+ def get_by_legacy_id(self, user_pk: int, legacy_id: str) -> Optional[Room]:
799
+ with self.session() as session:
800
+ return session.execute(
801
+ select(Room)
802
+ .where(Room.user_account_id == user_pk)
803
+ .where(Room.legacy_id == legacy_id)
804
+ ).scalar()
805
+
806
+ def update_subject_setter(self, room_pk: int, subject_setter: str | None):
807
+ with self.session() as session:
808
+ session.execute(
809
+ update(Room)
810
+ .where(Room.id == room_pk)
811
+ .values(subject_setter=subject_setter)
812
+ )
813
+ session.commit()
814
+
815
+ def update(self, muc: "LegacyMUC") -> int:
816
+ with self.session() as session:
817
+ if muc.pk is None:
818
+ row = Room(
819
+ jid=muc.jid,
820
+ legacy_id=str(muc.legacy_id),
821
+ user_account_id=muc.user_pk,
822
+ )
823
+ else:
824
+ row = session.query(Room).filter(Room.id == muc.pk).one()
825
+
826
+ row.updated = True
827
+ row.extra_attributes = muc.serialize_extra_attributes()
828
+ row.name = muc.name
829
+ row.description = muc.description
830
+ row.user_resources = (
831
+ None
832
+ if not muc._user_resources
833
+ else json.dumps(list(muc._user_resources))
834
+ )
835
+ row.muc_type = muc.type
836
+ row.subject = muc.subject
837
+ row.subject_date = muc.subject_date
838
+ row.subject_setter = muc.subject_setter
839
+ row.participants_filled = muc._participants_filled
840
+ row.n_participants = muc._n_participants
841
+ row.user_nick = muc.user_nick
842
+ session.add(row)
843
+ session.commit()
844
+ return row.id
845
+
846
+ def update_subject_date(
847
+ self, room_pk: int, subject_date: Optional[datetime]
848
+ ) -> None:
849
+ with self.session() as session:
850
+ session.execute(
851
+ update(Room).where(Room.id == room_pk).values(subject_date=subject_date)
852
+ )
853
+ session.commit()
854
+
855
+ def update_subject(self, room_pk: int, subject: Optional[str]) -> None:
856
+ with self.session() as session:
857
+ session.execute(
858
+ update(Room).where(Room.id == room_pk).values(subject=subject)
859
+ )
860
+ session.commit()
861
+
862
+ def update_description(self, room_pk: int, desc: Optional[str]) -> None:
863
+ with self.session() as session:
864
+ session.execute(
865
+ update(Room).where(Room.id == room_pk).values(description=desc)
866
+ )
867
+ session.commit()
868
+
869
def update_name(self, room_pk: int, name: Optional[str]) -> None:
    """Write ``name`` to room ``room_pk`` and commit."""
    stmt = update(Room).where(Room.id == room_pk).values(name=name)
    with self.session() as session:
        session.execute(stmt)
        session.commit()
873
+
874
def update_n_participants(self, room_pk: int, n: Optional[int]) -> None:
    """Write ``n`` to the ``n_participants`` column of room ``room_pk`` and commit."""
    stmt = update(Room).where(Room.id == room_pk).values(n_participants=n)
    with self.session() as session:
        session.execute(stmt)
        session.commit()
880
+
881
def delete(self, room_pk: int) -> None:
    """
    Delete room ``room_pk`` together with all of its participants, in one
    committed transaction.
    """
    with self.session() as session:
        # Children first: Participant rows point at the room through
        # ``room_id``. If that is an enforced, non-deferred foreign key
        # (schema not visible here), deleting the Room row first would be
        # rejected. Deleting participants first is correct either way.
        session.execute(delete(Participant).where(Participant.room_id == room_pk))
        session.execute(delete(Room).where(Room.id == room_pk))
        session.commit()
886
+
887
def set_resource(self, room_pk: int, resources: set[str]) -> None:
    """
    Store ``resources`` in the ``user_resources`` column of room ``room_pk``.

    The set is serialised as a JSON list. An empty set is stored as NULL.
    """
    encoded = json.dumps(list(resources)) if resources else None
    stmt = update(Room).where(Room.id == room_pk).values(user_resources=encoded)
    with self.session() as session:
        session.execute(stmt)
        session.commit()
899
+
900
def nickname_is_available(self, room_pk: int, nickname: str) -> bool:
    """Return True when no participant of room ``room_pk`` uses ``nickname``."""
    query = (
        select(Participant)
        .where(Participant.room_id == room_pk)
        .where(Participant.nickname == nickname)
    )
    with self.session() as session:
        holder = session.execute(query).scalar()
    return holder is None
910
+
911
def set_participants_filled(self, room_pk: int, val=True) -> None:
    """Write ``val`` to the ``participants_filled`` flag of room ``room_pk`` and commit."""
    stmt = update(Room).where(Room.id == room_pk).values(participants_filled=val)
    with self.session() as session:
        session.execute(stmt)
        session.commit()
917
+
918
def set_history_filled(self, room_pk: int, val=True) -> None:
    """
    Write ``val`` to the ``history_filled`` flag of room ``room_pk`` and commit.

    :param val: the value to store. It defaults to True, so existing
        single-argument callers are unaffected. Previously this parameter was
        ignored and True was always written.
    """
    with self.session() as session:
        session.execute(
            update(Room).where(Room.id == room_pk).values(history_filled=val)
        )
        session.commit()
924
+
925
def get_all(self, user_pk: int) -> Iterator[Room]:
    """
    Yield every room whose ``user_account_id`` equals ``user_pk``.

    The session stays open until the generator is exhausted or closed.
    """
    query = select(Room).where(Room.user_account_id == user_pk)
    with self.session() as session:
        for room in session.execute(query).scalars():
            yield room
930
+
931
+
932
class ParticipantStore(EngineMixin):
    """
    Database access for MUC participants (:class:`Participant`) and the hats
    (:class:`Hat`) attached to them.

    Each method opens a session through ``self.session()``. Methods that modify
    rows commit before returning.
    """

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)
        # Every Participant and Hat row is wiped whenever a store is
        # constructed, so these tables start out empty for each instance.
        with self.session() as session:
            session.execute(delete(Participant))
            session.execute(delete(Hat))
            session.commit()

    def add(self, room_pk: int, nickname: str) -> int:
        """
        Return the primary key of participant ``nickname`` in room ``room_pk``.

        A new row is inserted (and committed) only if none exists yet.
        """
        with self.session() as session:
            existing = session.execute(
                select(Participant.id)
                .where(Participant.room_id == room_pk)
                .where(Participant.nickname == nickname)
            ).scalar()
            if existing is not None:
                return existing
            participant = Participant(room_id=room_pk, nickname=nickname)
            session.add(participant)
            session.commit()
            return participant.id

    def get_by_nickname(self, room_pk: int, nickname: str) -> Optional[Participant]:
        """Return the participant of room ``room_pk`` named ``nickname``, or None."""
        with self.session() as session:
            return session.execute(
                select(Participant)
                .where(Participant.room_id == room_pk)
                .where(Participant.nickname == nickname)
            ).scalar()

    def get_by_resource(self, room_pk: int, resource: str) -> Optional[Participant]:
        """Return the participant of room ``room_pk`` with ``resource``, or None."""
        with self.session() as session:
            return session.execute(
                select(Participant)
                .where(Participant.room_id == room_pk)
                .where(Participant.resource == resource)
            ).scalar()

    def get_by_contact(self, room_pk: int, contact_pk: int) -> Optional[Participant]:
        """Return the participant of room ``room_pk`` linked to ``contact_pk``, or None."""
        with self.session() as session:
            return session.execute(
                select(Participant)
                .where(Participant.room_id == room_pk)
                .where(Participant.contact_id == contact_pk)
            ).scalar()

    def get_all(self, room_pk: int, user_included=True) -> Iterator[Participant]:
        """
        Yield the participants of room ``room_pk``.

        :param user_included: when False, rows flagged ``is_user`` are skipped.

        The session stays open until the generator is exhausted or closed.
        """
        with self.session() as session:
            q = select(Participant).where(Participant.room_id == room_pk)
            if not user_included:
                q = q.where(~Participant.is_user)
            yield from session.execute(q).scalars()

    def get_for_contact(self, contact_pk: int) -> Iterator[Participant]:
        """Yield every participant row (in any room) linked to ``contact_pk``."""
        with self.session() as session:
            yield from session.execute(
                select(Participant).where(Participant.contact_id == contact_pk)
            ).scalars()

    def update(self, participant: "LegacyParticipant") -> None:
        """
        Write the state of ``participant`` to its existing row and commit.

        The fields written are resource, affiliation, role, presence-sent flag,
        user flag and linked contact. Hats are not written here. They are
        persisted separately by :meth:`set_hats`.
        """
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant.pk)
                .values(
                    resource=participant.jid.resource,
                    affiliation=participant.affiliation,
                    role=participant.role,
                    presence_sent=participant._presence_sent,  # type:ignore
                    is_user=participant.is_user,
                    contact_id=(
                        None
                        if participant.contact is None
                        else participant.contact.contact_pk
                    ),
                )
            )
            session.commit()

    def add_hat(self, uri: str, title: str) -> Hat:
        """Return the Hat with this ``uri`` and ``title``, creating it if needed."""
        with self.session() as session:
            existing = session.execute(
                select(Hat).where(Hat.uri == uri).where(Hat.title == title)
            ).scalar()
            if existing is not None:
                return existing
            hat = Hat(uri=uri, title=title)
            session.add(hat)
            session.commit()
            return hat

    def set_presence_sent(self, participant_pk: int) -> None:
        """Mark participant ``participant_pk`` as having had its presence sent."""
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant_pk)
                .values(presence_sent=True)
            )
            session.commit()

    def set_affiliation(self, participant_pk: int, affiliation: MucAffiliation) -> None:
        """Store ``affiliation`` for participant ``participant_pk`` and commit."""
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant_pk)
                .values(affiliation=affiliation)
            )
            session.commit()

    def set_role(self, participant_pk: int, role: MucRole) -> None:
        """Store ``role`` for participant ``participant_pk`` and commit."""
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant_pk)
                .values(role=role)
            )
            session.commit()

    def set_hats(self, participant_pk: int, hats: list[HatTuple]) -> None:
        """
        Replace the hats of participant ``participant_pk`` with ``hats``.

        Each ``(uri, title)`` entry is looked up or created through
        :meth:`add_hat`. A hat that appears twice is attached only once.

        :raises ValueError: if no participant has this primary key.
        """
        with self.session() as session:
            part = session.execute(
                select(Participant).where(Participant.id == participant_pk)
            ).scalar()
            if part is None:
                raise ValueError(f"No participant with primary key {participant_pk}")
            part.hats.clear()
            for h in hats:
                # NOTE(review): add_hat() opens its own session and commits.
                # Attaching its result to ``part`` presumably relies on
                # self.session() handing back this same session. Confirm.
                hat = self.add_hat(*h)
                if hat in part.hats:
                    continue
                part.hats.append(hat)
            session.commit()

    def delete(self, participant_pk: int) -> None:
        """Delete participant ``participant_pk`` and commit."""
        with self.session() as session:
            session.execute(delete(Participant).where(Participant.id == participant_pk))
            # Every other mutator commits explicitly. Without this commit, a
            # standalone session rolls the deletion back when it is closed.
            session.commit()

    def get_count(self, room_pk: int) -> int:
        """Return the number of participant rows in room ``room_pk``."""
        with self.session() as session:
            # Use a plain WHERE rather than an aggregate FILTER clause over the
            # whole table. It works on every SQL backend and lets the database
            # use an index on room_id if one exists.
            return session.execute(
                select(count(Participant.id)).where(Participant.room_id == room_pk)
            ).scalar_one()
1075
+
1076
+
1077
# NOTE(review): presumably the Session shared by nested EngineMixin.session()
# calls, and None when none is active. EngineMixin is not visible here; confirm.
_session: Optional[Session] = None

# Logger for this module.
log = logging.getLogger(name=__name__)