slidge 0.2.12__py3-none-any.whl → 0.3.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. slidge/__init__.py +5 -2
  2. slidge/command/adhoc.py +9 -3
  3. slidge/command/admin.py +16 -12
  4. slidge/command/base.py +16 -12
  5. slidge/command/chat_command.py +25 -16
  6. slidge/command/user.py +7 -8
  7. slidge/contact/contact.py +119 -209
  8. slidge/contact/roster.py +106 -105
  9. slidge/core/config.py +2 -43
  10. slidge/core/dispatcher/caps.py +9 -2
  11. slidge/core/dispatcher/disco.py +13 -3
  12. slidge/core/dispatcher/message/__init__.py +1 -1
  13. slidge/core/dispatcher/message/chat_state.py +17 -8
  14. slidge/core/dispatcher/message/marker.py +7 -5
  15. slidge/core/dispatcher/message/message.py +117 -92
  16. slidge/core/dispatcher/muc/__init__.py +1 -1
  17. slidge/core/dispatcher/muc/admin.py +4 -4
  18. slidge/core/dispatcher/muc/mam.py +10 -6
  19. slidge/core/dispatcher/muc/misc.py +4 -2
  20. slidge/core/dispatcher/muc/owner.py +5 -3
  21. slidge/core/dispatcher/muc/ping.py +3 -1
  22. slidge/core/dispatcher/presence.py +21 -15
  23. slidge/core/dispatcher/registration.py +20 -12
  24. slidge/core/dispatcher/search.py +7 -3
  25. slidge/core/dispatcher/session_dispatcher.py +13 -5
  26. slidge/core/dispatcher/util.py +37 -27
  27. slidge/core/dispatcher/vcard.py +7 -4
  28. slidge/core/gateway.py +168 -84
  29. slidge/core/mixins/__init__.py +1 -11
  30. slidge/core/mixins/attachment.py +163 -148
  31. slidge/core/mixins/avatar.py +100 -177
  32. slidge/core/mixins/db.py +50 -2
  33. slidge/core/mixins/message.py +19 -17
  34. slidge/core/mixins/message_maker.py +29 -15
  35. slidge/core/mixins/message_text.py +38 -30
  36. slidge/core/mixins/presence.py +91 -35
  37. slidge/core/pubsub.py +42 -47
  38. slidge/core/session.py +88 -57
  39. slidge/db/alembic/versions/0337c90c0b96_unify_legacy_xmpp_id_mappings.py +183 -0
  40. slidge/db/alembic/versions/4dbd23a3f868_new_avatar_store.py +56 -0
  41. slidge/db/alembic/versions/54ce3cde350c_use_hash_for_avatar_filenames.py +50 -0
  42. slidge/db/alembic/versions/58b98dacf819_refactor.py +118 -0
  43. slidge/db/alembic/versions/75a62b74b239_ditch_hats_table.py +74 -0
  44. slidge/db/avatar.py +150 -119
  45. slidge/db/meta.py +33 -22
  46. slidge/db/models.py +68 -117
  47. slidge/db/store.py +412 -1094
  48. slidge/group/archive.py +61 -54
  49. slidge/group/bookmarks.py +74 -55
  50. slidge/group/participant.py +135 -142
  51. slidge/group/room.py +315 -312
  52. slidge/main.py +28 -18
  53. slidge/migration.py +2 -12
  54. slidge/slixfix/__init__.py +20 -4
  55. slidge/slixfix/delivery_receipt.py +6 -4
  56. slidge/slixfix/link_preview/link_preview.py +1 -1
  57. slidge/slixfix/link_preview/stanza.py +1 -1
  58. slidge/slixfix/roster.py +5 -7
  59. slidge/slixfix/xep_0077/register.py +8 -8
  60. slidge/slixfix/xep_0077/stanza.py +7 -7
  61. slidge/slixfix/xep_0100/gateway.py +12 -13
  62. slidge/slixfix/xep_0153/vcard_avatar.py +1 -1
  63. slidge/slixfix/xep_0292/vcard4.py +1 -1
  64. slidge/util/archive_msg.py +11 -5
  65. slidge/util/conf.py +23 -20
  66. slidge/util/jid_escaping.py +1 -1
  67. slidge/{core/mixins → util}/lock.py +6 -6
  68. slidge/util/test.py +30 -29
  69. slidge/util/types.py +22 -18
  70. slidge/util/util.py +19 -22
  71. {slidge-0.2.12.dist-info → slidge-0.3.0a0.dist-info}/METADATA +1 -1
  72. slidge-0.3.0a0.dist-info/RECORD +117 -0
  73. {slidge-0.2.12.dist-info → slidge-0.3.0a0.dist-info}/WHEEL +1 -1
  74. slidge-0.2.12.dist-info/RECORD +0 -112
  75. {slidge-0.2.12.dist-info → slidge-0.3.0a0.dist-info}/entry_points.txt +0 -0
  76. {slidge-0.2.12.dist-info → slidge-0.3.0a0.dist-info}/licenses/LICENSE +0 -0
  77. {slidge-0.2.12.dist-info → slidge-0.3.0a0.dist-info}/top_level.txt +0 -0
slidge/db/store.py CHANGED
@@ -1,550 +1,311 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import hashlib
4
- import json
5
4
  import logging
6
5
  import uuid
7
- from contextlib import contextmanager
8
6
  from datetime import datetime, timedelta, timezone
9
7
  from mimetypes import guess_extension
10
- from typing import TYPE_CHECKING, Collection, Iterator, Optional, Type
8
+ from typing import Collection, Iterator, Optional, Type
11
9
 
12
- from slixmpp import JID, Iq, Message, Presence
13
10
  from slixmpp.exceptions import XMPPError
14
11
  from slixmpp.plugins.xep_0231.stanza import BitsOfBinary
15
12
  from sqlalchemy import Engine, delete, select, update
16
- from sqlalchemy.orm import Session, attributes, load_only
17
- from sqlalchemy.sql.functions import count
13
+ from sqlalchemy.exc import InvalidRequestError
14
+ from sqlalchemy.orm import Session, attributes, sessionmaker
18
15
 
19
16
  from ..core import config
20
17
  from ..util.archive_msg import HistoryMessage
21
- from ..util.types import (
22
- URL,
23
- CachedPresence,
24
- ClientType,
25
- MamMetadata,
26
- MucAffiliation,
27
- MucRole,
28
- Sticker,
29
- )
30
- from ..util.types import Hat as HatTuple
18
+ from ..util.types import MamMetadata, Sticker
31
19
  from .meta import Base
32
20
  from .models import (
33
21
  ArchivedMessage,
34
22
  ArchivedMessageSource,
35
- Attachment,
36
- Avatar,
37
23
  Bob,
38
24
  Contact,
39
25
  ContactSent,
26
+ DirectMessages,
27
+ DirectThreads,
40
28
  GatewayUser,
41
- Hat,
42
- LegacyIdsMulti,
29
+ GroupMessages,
30
+ GroupThreads,
43
31
  Participant,
44
32
  Room,
45
- XmppIdsMulti,
46
- XmppToLegacyEnum,
47
- XmppToLegacyIds,
48
- participant_hats,
49
33
  )
50
34
 
51
- if TYPE_CHECKING:
52
- from ..contact.contact import LegacyContact
53
- from ..group.participant import LegacyParticipant
54
- from ..group.room import LegacyMUC
55
35
 
36
+ class UpdatedMixin:
37
+ model: Type[Base] = NotImplemented
56
38
 
57
- class EngineMixin:
58
- def __init__(self, engine: Engine):
59
- self._engine = engine
39
+ def __init__(self, session: Session) -> None:
40
+ session.execute(update(self.model).values(updated=False))
60
41
 
61
- # TODO: we should not have a global Session object but instead build Sessions with different parameters
62
- # depending on the context (startup, incoming XMPP event, incoming legacy event).
63
- @contextmanager
64
- def session(self, **session_kwargs) -> Iterator[Session]:
65
- global _session
66
- if _session is not None:
67
- yield _session
68
- return
69
- with Session(self._engine, **session_kwargs) as session:
70
- _session = session
71
- try:
72
- yield session
73
- finally:
74
- _session = None
42
+ def get_by_pk(self, session: Session, pk: int) -> Type[Base]:
43
+ stmt = select(self.model).where(self.model.id == pk) # type:ignore
44
+ return session.scalar(stmt)
75
45
 
76
46
 
77
- class UpdatedMixin(EngineMixin):
78
- model: Type[Base] = NotImplemented
47
+ class SlidgeStore:
48
+ def __init__(self, engine: Engine) -> None:
49
+ self._engine = engine
50
+ self.session = sessionmaker(engine)
79
51
 
80
- def __init__(self, *a, **kw):
81
- super().__init__(*a, **kw)
52
+ self.users = UserStore(self.session)
53
+ self.avatars = AvatarStore(self.session)
54
+ self.id_map = IdMapStore()
55
+ self.bob = BobStore()
82
56
  with self.session() as session:
83
- session.execute(update(self.model).values(updated=False)) # type:ignore
57
+ self.contacts = ContactStore(session)
58
+ self.mam = MAMStore(session, self.session)
59
+ self.rooms = RoomStore(session)
60
+ self.participants = ParticipantStore(session)
84
61
  session.commit()
85
62
 
86
- def get_by_pk(self, pk: int) -> Optional[Base]:
87
- with self.session() as session:
88
- return session.execute(
89
- select(self.model).where(self.model.id == pk) # type:ignore
90
- ).scalar()
91
63
 
64
+ class UserStore:
65
+ def __init__(self, session_maker) -> None:
66
+ self.session = session_maker
92
67
 
93
- class SlidgeStore(EngineMixin):
94
- def __init__(self, engine: Engine):
95
- super().__init__(engine)
96
- self.users = UserStore(engine)
97
- self.avatars = AvatarStore(engine)
98
- self.contacts = ContactStore(engine)
99
- self.mam = MAMStore(engine)
100
- self.multi = MultiStore(engine)
101
- self.attachments = AttachmentStore(engine)
102
- self.rooms = RoomStore(engine)
103
- self.sent = SentStore(engine)
104
- self.participants = ParticipantStore(engine)
105
- self.bob = BobStore(engine)
106
-
107
-
108
- class UserStore(EngineMixin):
109
- def new(self, jid: JID, legacy_module_data: dict) -> GatewayUser:
110
- if jid.resource:
111
- jid = JID(jid.bare)
68
+ def update(self, user: GatewayUser) -> None:
112
69
  with self.session(expire_on_commit=False) as session:
113
- user = session.execute(
114
- select(GatewayUser).where(GatewayUser.jid == jid)
115
- ).scalar()
116
- if user is not None:
117
- return user
118
- user = GatewayUser(jid=jid, legacy_module_data=legacy_module_data)
119
- session.add(user)
120
- session.commit()
121
- return user
122
-
123
- def update(self, user: GatewayUser):
124
- # https://github.com/sqlalchemy/sqlalchemy/discussions/6473
125
- attributes.flag_modified(user, "legacy_module_data")
126
- attributes.flag_modified(user, "preferences")
127
- with self.session() as session:
70
+ # https://github.com/sqlalchemy/sqlalchemy/discussions/6473
71
+ try:
72
+ attributes.flag_modified(user, "legacy_module_data")
73
+ attributes.flag_modified(user, "preferences")
74
+ except InvalidRequestError:
75
+ pass
128
76
  session.add(user)
129
77
  session.commit()
130
78
 
131
- def get_all(self) -> Iterator[GatewayUser]:
132
- with self.session() as session:
133
- yield from session.execute(select(GatewayUser)).scalars()
134
-
135
- def get(self, jid: JID) -> Optional[GatewayUser]:
136
- with self.session() as session:
137
- return session.execute(
138
- select(GatewayUser).where(GatewayUser.jid == jid.bare)
139
- ).scalar()
140
-
141
- def get_by_stanza(self, stanza: Iq | Message | Presence) -> Optional[GatewayUser]:
142
- return self.get(stanza.get_from())
143
-
144
- def delete(self, jid: JID) -> None:
145
- with self.session() as session:
146
- session.delete(self.get(jid))
147
- session.commit()
148
-
149
- def set_avatar_hash(self, pk: int, h: str | None = None) -> None:
150
- with self.session() as session:
151
- session.execute(
152
- update(GatewayUser).where(GatewayUser.id == pk).values(avatar_hash=h)
153
- )
154
- session.commit()
155
-
156
-
157
- class AvatarStore(EngineMixin):
158
- def get_by_url(self, url: URL | str) -> Optional[Avatar]:
159
- with self.session() as session:
160
- return session.execute(select(Avatar).where(Avatar.url == url)).scalar()
161
79
 
162
- def get_by_pk(self, pk: int) -> Optional[Avatar]:
163
- with self.session() as session:
164
- return session.execute(select(Avatar).where(Avatar.id == pk)).scalar()
80
+ class AvatarStore:
81
+ def __init__(self, session_maker) -> None:
82
+ self.session = session_maker
165
83
 
166
- def delete_by_pk(self, pk: int):
167
- with self.session() as session:
168
- session.execute(delete(Avatar).where(Avatar.id == pk))
169
- session.commit()
170
84
 
171
- def get_all(self) -> Iterator[Avatar]:
172
- with self.session() as session:
173
- yield from session.execute(select(Avatar)).scalars()
85
+ LegacyToXmppType = (
86
+ Type[DirectMessages]
87
+ | Type[DirectThreads]
88
+ | Type[GroupMessages]
89
+ | Type[GroupThreads]
90
+ )
174
91
 
175
92
 
176
- class SentStore(EngineMixin):
177
- def set_message(self, user_pk: int, legacy_id: str, xmpp_id: str) -> None:
178
- with self.session() as session:
179
- msg = (
180
- session.query(XmppToLegacyIds)
181
- .filter(XmppToLegacyIds.user_account_id == user_pk)
182
- .filter(XmppToLegacyIds.legacy_id == legacy_id)
183
- .filter(XmppToLegacyIds.xmpp_id == xmpp_id)
184
- .scalar()
93
+ class IdMapStore:
94
+ @staticmethod
95
+ def _set(
96
+ session: Session,
97
+ foreign_key: int,
98
+ legacy_id: str,
99
+ xmpp_ids: list[str],
100
+ type_: LegacyToXmppType,
101
+ ) -> None:
102
+ kwargs = dict(foreign_key=foreign_key, legacy_id=legacy_id)
103
+ ids = session.scalars(
104
+ select(type_.id).filter(
105
+ type_.foreign_key == foreign_key, type_.legacy_id == legacy_id
185
106
  )
186
- if msg is None:
187
- msg = XmppToLegacyIds(user_account_id=user_pk)
188
- else:
189
- log.debug("Resetting a DM from sent store")
190
- msg.legacy_id = legacy_id
191
- msg.xmpp_id = xmpp_id
192
- msg.type = XmppToLegacyEnum.DM
107
+ )
108
+ if ids:
109
+ log.debug("Resetting legacy ID %s", legacy_id)
110
+ session.execute(delete(type_).where(type_.id.in_(ids)))
111
+ for xmpp_id in xmpp_ids:
112
+ msg = type_(xmpp_id=xmpp_id, **kwargs)
193
113
  session.add(msg)
194
- session.commit()
195
114
 
196
- def get_xmpp_id(self, user_pk: int, legacy_id: str) -> Optional[str]:
197
- with self.session() as session:
198
- return session.execute(
199
- select(XmppToLegacyIds.xmpp_id)
200
- .where(XmppToLegacyIds.user_account_id == user_pk)
201
- .where(XmppToLegacyIds.legacy_id == legacy_id)
202
- .where(XmppToLegacyIds.type == XmppToLegacyEnum.DM)
203
- ).scalar()
115
+ def set_thread(
116
+ self,
117
+ session: Session,
118
+ foreign_key: int,
119
+ legacy_id: str,
120
+ xmpp_id: str,
121
+ group: bool,
122
+ ) -> None:
123
+ self._set(
124
+ session,
125
+ foreign_key,
126
+ legacy_id,
127
+ [xmpp_id],
128
+ GroupThreads if group else DirectThreads,
129
+ )
204
130
 
205
- def get_legacy_id(self, user_pk: int, xmpp_id: str) -> Optional[str]:
206
- with self.session() as session:
207
- return session.execute(
208
- select(XmppToLegacyIds.legacy_id)
209
- .where(XmppToLegacyIds.user_account_id == user_pk)
210
- .where(XmppToLegacyIds.xmpp_id == xmpp_id)
211
- .where(XmppToLegacyIds.type == XmppToLegacyEnum.DM)
212
- ).scalar()
131
+ def set_msg(
132
+ self,
133
+ session: Session,
134
+ foreign_key: int,
135
+ legacy_id: str,
136
+ xmpp_ids: list[str],
137
+ group: bool,
138
+ ) -> None:
139
+ self._set(
140
+ session,
141
+ foreign_key,
142
+ legacy_id,
143
+ xmpp_ids,
144
+ GroupMessages if group else DirectMessages,
145
+ )
213
146
 
214
- def set_group_message(self, user_pk: int, legacy_id: str, xmpp_id: str) -> None:
215
- with self.session() as session:
216
- msg = XmppToLegacyIds(
217
- user_account_id=user_pk,
218
- legacy_id=legacy_id,
219
- xmpp_id=xmpp_id,
220
- type=XmppToLegacyEnum.GROUP_CHAT,
147
+ @staticmethod
148
+ def _get(
149
+ session: Session, foreign_key: int, legacy_id: str, type_: LegacyToXmppType
150
+ ) -> list[str]:
151
+ return list(
152
+ session.scalars(
153
+ select(type_.xmpp_id).filter_by(
154
+ foreign_key=foreign_key, legacy_id=legacy_id
155
+ )
221
156
  )
222
- session.add(msg)
223
- session.commit()
157
+ )
224
158
 
225
- def get_group_xmpp_id(self, user_pk: int, legacy_id: str) -> Optional[str]:
226
- with self.session() as session:
227
- return session.execute(
228
- select(XmppToLegacyIds.xmpp_id)
229
- .where(XmppToLegacyIds.user_account_id == user_pk)
230
- .where(XmppToLegacyIds.legacy_id == legacy_id)
231
- .where(XmppToLegacyIds.type == XmppToLegacyEnum.GROUP_CHAT)
232
- ).scalar()
159
+ def get_xmpp(
160
+ self, session: Session, foreign_key: int, legacy_id: str, group: bool
161
+ ) -> list[str]:
162
+ return self._get(
163
+ session,
164
+ foreign_key,
165
+ legacy_id,
166
+ GroupMessages if group else DirectMessages,
167
+ )
233
168
 
234
- def get_group_legacy_id(self, user_pk: int, xmpp_id: str) -> Optional[str]:
235
- with self.session() as session:
236
- return session.execute(
237
- select(XmppToLegacyIds.legacy_id)
238
- .where(XmppToLegacyIds.user_account_id == user_pk)
239
- .where(XmppToLegacyIds.xmpp_id == xmpp_id)
240
- .where(XmppToLegacyIds.type == XmppToLegacyEnum.GROUP_CHAT)
241
- ).scalar()
169
+ @staticmethod
170
+ def _get_legacy(
171
+ session: Session, foreign_key: int, xmpp_id: str, type_: LegacyToXmppType
172
+ ) -> Optional[str]:
173
+ return session.scalar(
174
+ select(type_.legacy_id).filter_by(foreign_key=foreign_key, xmpp_id=xmpp_id)
175
+ )
242
176
 
243
- def set_thread(self, user_pk: int, legacy_id: str, xmpp_id: str) -> None:
244
- with self.session() as session:
245
- msg = XmppToLegacyIds(
246
- user_account_id=user_pk,
247
- legacy_id=legacy_id,
248
- xmpp_id=xmpp_id,
249
- type=XmppToLegacyEnum.THREAD,
250
- )
251
- session.add(msg)
252
- session.commit()
177
+ def get_legacy(
178
+ self, session: Session, foreign_key: int, xmpp_id: str, group: bool
179
+ ) -> Optional[str]:
180
+ return self._get_legacy(
181
+ session,
182
+ foreign_key,
183
+ xmpp_id,
184
+ GroupMessages if group else DirectMessages,
185
+ )
253
186
 
254
- def get_legacy_thread(self, user_pk: int, xmpp_id: str) -> Optional[str]:
255
- with self.session() as session:
256
- return session.execute(
257
- select(XmppToLegacyIds.legacy_id)
258
- .where(XmppToLegacyIds.user_account_id == user_pk)
259
- .where(XmppToLegacyIds.xmpp_id == xmpp_id)
260
- .where(XmppToLegacyIds.type == XmppToLegacyEnum.THREAD)
261
- ).scalar()
187
+ def get_thread(
188
+ self, session: Session, foreign_key: int, xmpp_id: str, group: bool
189
+ ) -> Optional[str]:
190
+ return self._get_legacy(
191
+ session,
192
+ foreign_key,
193
+ xmpp_id,
194
+ GroupThreads if group else DirectThreads,
195
+ )
262
196
 
263
- def was_sent_by_user(self, user_pk: int, legacy_id: str) -> bool:
264
- with self.session() as session:
265
- return (
266
- session.execute(
267
- select(XmppToLegacyIds.legacy_id)
268
- .where(XmppToLegacyIds.user_account_id == user_pk)
269
- .where(XmppToLegacyIds.legacy_id == legacy_id)
270
- ).scalar()
271
- is not None
272
- )
197
+ @staticmethod
198
+ def was_sent_by_user(
199
+ session: Session, foreign_key: int, legacy_id: str, group: bool
200
+ ) -> bool:
201
+ type_ = GroupMessages if group else DirectMessages
202
+ return (
203
+ session.scalar(
204
+ select(type_.id).filter_by(foreign_key=foreign_key, legacy_id=legacy_id)
205
+ )
206
+ is not None
207
+ )
273
208
 
274
209
 
275
210
  class ContactStore(UpdatedMixin):
276
211
  model = Contact
277
212
 
278
- def __init__(self, *a, **k):
279
- super().__init__(*a, **k)
280
- with self.session() as session:
281
- session.execute(update(Contact).values(cached_presence=False))
282
- session.commit()
283
-
284
- def get_all(self, user_pk: int) -> Iterator[Contact]:
285
- with self.session() as session:
286
- yield from session.execute(
287
- select(Contact).where(Contact.user_account_id == user_pk)
288
- ).scalars()
289
-
290
- def get_by_jid(self, user_pk: int, jid: JID) -> Optional[Contact]:
291
- with self.session() as session:
292
- return session.execute(
293
- select(Contact)
294
- .where(Contact.jid == jid.bare)
295
- .where(Contact.user_account_id == user_pk)
296
- ).scalar()
297
-
298
- def get_by_legacy_id(self, user_pk: int, legacy_id: str) -> Optional[Contact]:
299
- with self.session() as session:
300
- return session.execute(
301
- select(Contact)
302
- .where(Contact.legacy_id == legacy_id)
303
- .where(Contact.user_account_id == user_pk)
304
- ).scalar()
305
-
306
- def update_nick(self, contact_pk: int, nick: Optional[str]) -> None:
307
- with self.session() as session:
308
- session.execute(
309
- update(Contact).where(Contact.id == contact_pk).values(nick=nick)
310
- )
311
- session.commit()
312
-
313
- def get_presence(self, contact_pk: int) -> Optional[CachedPresence]:
314
- with self.session() as session:
315
- presence = session.execute(
316
- select(
317
- Contact.last_seen,
318
- Contact.ptype,
319
- Contact.pstatus,
320
- Contact.pshow,
321
- Contact.cached_presence,
322
- ).where(Contact.id == contact_pk)
323
- ).first()
324
- if presence is None or not presence[-1]:
325
- return None
326
- return CachedPresence(*presence[:-1])
327
-
328
- def set_presence(self, contact_pk: int, presence: CachedPresence) -> None:
329
- with self.session() as session:
330
- session.execute(
331
- update(Contact)
332
- .where(Contact.id == contact_pk)
333
- .values(**presence._asdict(), cached_presence=True)
334
- )
335
- session.commit()
336
-
337
- def reset_presence(self, contact_pk: int):
338
- with self.session() as session:
339
- session.execute(
340
- update(Contact)
341
- .where(Contact.id == contact_pk)
342
- .values(
343
- last_seen=None,
344
- ptype=None,
345
- pstatus=None,
346
- pshow=None,
347
- cached_presence=False,
348
- )
349
- )
350
- session.commit()
351
-
352
- def set_avatar(
353
- self, contact_pk: int, avatar_pk: Optional[int], avatar_legacy_id: Optional[str]
354
- ):
355
- with self.session() as session:
356
- session.execute(
357
- update(Contact)
358
- .where(Contact.id == contact_pk)
359
- .values(avatar_id=avatar_pk, avatar_legacy_id=avatar_legacy_id)
360
- )
361
- session.commit()
362
-
363
- def get_avatar_legacy_id(self, contact_pk: int) -> Optional[str]:
364
- with self.session() as session:
365
- contact = session.execute(
366
- select(Contact).where(Contact.id == contact_pk)
367
- ).scalar()
368
- if contact is None or contact.avatar is None:
369
- return None
370
- return contact.avatar_legacy_id
371
-
372
- def update(self, contact: "LegacyContact", commit=True) -> int:
373
- with self.session() as session:
374
- if contact.contact_pk is None:
375
- if contact.cached_presence is not None:
376
- presence_kwargs = contact.cached_presence._asdict()
377
- presence_kwargs["cached_presence"] = True
378
- else:
379
- presence_kwargs = {}
380
- row = Contact(
381
- jid=contact.jid.bare,
382
- legacy_id=str(contact.legacy_id),
383
- user_account_id=contact.user_pk,
384
- **presence_kwargs,
385
- )
386
- else:
387
- row = (
388
- session.query(Contact)
389
- .filter(Contact.id == contact.contact_pk)
390
- .one()
391
- )
392
- row.nick = contact.name
393
- row.is_friend = contact.is_friend
394
- row.added_to_roster = contact.added_to_roster
395
- row.updated = True
396
- row.extra_attributes = contact.serialize_extra_attributes()
397
- row.caps_ver = contact._caps_ver
398
- row.vcard = contact._vcard
399
- row.vcard_fetched = contact._vcard_fetched
400
- row.client_type = contact.client_type
401
- session.add(row)
402
- if commit:
403
- session.commit()
404
- return row.id
405
-
406
- def set_vcard(self, contact_pk: int, vcard: str | None) -> None:
407
- with self.session() as session:
408
- session.execute(
409
- update(Contact)
410
- .where(Contact.id == contact_pk)
411
- .values(vcard=vcard, vcard_fetched=True)
412
- )
413
- session.commit()
213
+ def __init__(self, session: Session) -> None:
214
+ super().__init__(session)
215
+ session.execute(update(Contact).values(cached_presence=False))
414
216
 
415
- def add_to_sent(self, contact_pk: int, msg_id: str) -> None:
416
- with self.session() as session:
417
- if (
418
- session.query(ContactSent.id)
419
- .where(ContactSent.contact_id == contact_pk)
420
- .where(ContactSent.msg_id == msg_id)
421
- .first()
422
- ) is not None:
423
- log.warning(
424
- "Contact %s has already sent message %s", contact_pk, msg_id
425
- )
426
- return
427
- new = ContactSent(contact_id=contact_pk, msg_id=msg_id)
428
- session.add(new)
429
- session.commit()
217
+ @staticmethod
218
+ def add_to_sent(session: Session, contact_pk: int, msg_id: str) -> None:
219
+ if (
220
+ session.query(ContactSent.id)
221
+ .where(ContactSent.contact_id == contact_pk)
222
+ .where(ContactSent.msg_id == msg_id)
223
+ .first()
224
+ ) is not None:
225
+ log.warning("Contact %s has already sent message %s", contact_pk, msg_id)
226
+ return
227
+ new = ContactSent(contact_id=contact_pk, msg_id=msg_id)
228
+ session.add(new)
430
229
 
431
- def pop_sent_up_to(self, contact_pk: int, msg_id: str) -> list[str]:
230
+ @staticmethod
231
+ def pop_sent_up_to(session: Session, contact_pk: int, msg_id: str) -> list[str]:
432
232
  result = []
433
233
  to_del = []
434
- with self.session() as session:
435
- for row in session.execute(
436
- select(ContactSent)
437
- .where(ContactSent.contact_id == contact_pk)
438
- .order_by(ContactSent.id)
439
- ).scalars():
440
- to_del.append(row.id)
441
- result.append(row.msg_id)
442
- if row.msg_id == msg_id:
443
- break
444
- for row_id in to_del:
445
- session.execute(delete(ContactSent).where(ContactSent.id == row_id))
446
- session.commit()
234
+ for row in session.execute(
235
+ select(ContactSent)
236
+ .where(ContactSent.contact_id == contact_pk)
237
+ .order_by(ContactSent.id)
238
+ ).scalars():
239
+ to_del.append(row.id)
240
+ result.append(row.msg_id)
241
+ if row.msg_id == msg_id:
242
+ break
243
+ session.execute(delete(ContactSent).where(ContactSent.id.in_(to_del)))
447
244
  return result
448
245
 
449
- def set_friend(self, contact_pk: int, is_friend: bool) -> None:
450
- with self.session() as session:
451
- session.execute(
452
- update(Contact)
453
- .where(Contact.id == contact_pk)
454
- .values(is_friend=is_friend)
455
- )
456
- session.commit()
457
-
458
- def set_added_to_roster(self, contact_pk: int, value: bool) -> None:
459
- with self.session() as session:
460
- session.execute(
461
- update(Contact)
462
- .where(Contact.id == contact_pk)
463
- .values(added_to_roster=value)
464
- )
465
- session.commit()
466
-
467
- def delete(self, contact_pk: int) -> None:
468
- with self.session() as session:
469
- session.execute(delete(Contact).where(Contact.id == contact_pk))
470
- session.commit()
471
-
472
- def set_client_type(self, contact_pk: int, value: ClientType):
473
- with self.session() as session:
474
- session.execute(
475
- update(Contact)
476
- .where(Contact.id == contact_pk)
477
- .values(client_type=value)
478
- )
479
- session.commit()
480
-
481
246
 
482
- class MAMStore(EngineMixin):
483
- def __init__(self, *a, **kw):
484
- super().__init__(*a, **kw)
485
- with self.session() as session:
486
- session.execute(
487
- update(ArchivedMessage).values(source=ArchivedMessageSource.BACKFILL)
488
- )
489
- session.commit()
247
+ class MAMStore:
248
+ def __init__(self, session: Session, session_maker) -> None:
249
+ self.session = session_maker
250
+ session.execute(
251
+ update(ArchivedMessage).values(source=ArchivedMessageSource.BACKFILL)
252
+ )
490
253
 
491
- def nuke_older_than(self, days: int) -> None:
492
- with self.session() as session:
493
- session.execute(
494
- delete(ArchivedMessage).where(
495
- ArchivedMessage.timestamp < datetime.now() - timedelta(days=days)
496
- )
254
+ @staticmethod
255
+ def nuke_older_than(session: Session, days: int) -> None:
256
+ session.execute(
257
+ delete(ArchivedMessage).where(
258
+ ArchivedMessage.timestamp < datetime.now() - timedelta(days=days)
497
259
  )
498
- session.commit()
260
+ )
499
261
 
262
+ @staticmethod
500
263
  def add_message(
501
- self,
264
+ session: Session,
502
265
  room_pk: int,
503
266
  message: HistoryMessage,
504
267
  archive_only: bool,
505
- legacy_msg_id: str | None,
268
+ legacy_msg_id: Optional[str],
506
269
  ) -> None:
507
- with self.session() as session:
508
- source = (
509
- ArchivedMessageSource.BACKFILL
510
- if archive_only
511
- else ArchivedMessageSource.LIVE
512
- )
270
+ source = (
271
+ ArchivedMessageSource.BACKFILL
272
+ if archive_only
273
+ else ArchivedMessageSource.LIVE
274
+ )
275
+ existing = session.execute(
276
+ select(ArchivedMessage)
277
+ .where(ArchivedMessage.room_id == room_pk)
278
+ .where(ArchivedMessage.stanza_id == message.id)
279
+ ).scalar()
280
+ if existing is None and legacy_msg_id is not None:
513
281
  existing = session.execute(
514
282
  select(ArchivedMessage)
515
283
  .where(ArchivedMessage.room_id == room_pk)
516
- .where(ArchivedMessage.stanza_id == message.id)
284
+ .where(ArchivedMessage.legacy_id == legacy_msg_id)
517
285
  ).scalar()
518
- if existing is None and legacy_msg_id is not None:
519
- existing = session.execute(
520
- select(ArchivedMessage)
521
- .where(ArchivedMessage.room_id == room_pk)
522
- .where(ArchivedMessage.legacy_id == legacy_msg_id)
523
- ).scalar()
524
- if existing is not None:
525
- log.debug("Updating message %s in room %s", message.id, room_pk)
526
- existing.timestamp = message.when
527
- existing.stanza = str(message.stanza)
528
- existing.author_jid = message.stanza.get_from()
529
- existing.source = source
530
- existing.legacy_id = legacy_msg_id
531
- session.add(existing)
532
- session.commit()
533
- return
534
- mam_msg = ArchivedMessage(
535
- stanza_id=message.id,
536
- timestamp=message.when,
537
- stanza=str(message.stanza),
538
- author_jid=message.stanza.get_from(),
539
- room_id=room_pk,
540
- source=source,
541
- legacy_id=legacy_msg_id,
542
- )
543
- session.add(mam_msg)
544
- session.commit()
286
+ if existing is not None:
287
+ log.debug("Updating message %s in room %s", message.id, room_pk)
288
+ existing.timestamp = message.when
289
+ existing.stanza = str(message.stanza)
290
+ existing.author_jid = message.stanza.get_from()
291
+ existing.source = source
292
+ existing.legacy_id = legacy_msg_id
293
+ session.add(existing)
294
+ return
295
+ mam_msg = ArchivedMessage(
296
+ stanza_id=message.id,
297
+ timestamp=message.when,
298
+ stanza=str(message.stanza),
299
+ author_jid=message.stanza.get_from(),
300
+ room_id=room_pk,
301
+ source=source,
302
+ legacy_id=legacy_msg_id,
303
+ )
304
+ session.add(mam_msg)
545
305
 
306
+ @staticmethod
546
307
  def get_messages(
547
- self,
308
+ session: Session,
548
309
  room_pk: int,
549
310
  start_date: Optional[datetime] = None,
550
311
  end_date: Optional[datetime] = None,
@@ -553,612 +314,177 @@ class MAMStore(EngineMixin):
553
314
  ids: Collection[str] = (),
554
315
  last_page_n: Optional[int] = None,
555
316
  sender: Optional[str] = None,
556
- flip=False,
317
+ flip: bool = False,
557
318
  ) -> Iterator[HistoryMessage]:
558
- with self.session() as session:
559
- q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)
560
- if start_date is not None:
561
- q = q.where(ArchivedMessage.timestamp >= start_date)
562
- if end_date is not None:
563
- q = q.where(ArchivedMessage.timestamp <= end_date)
564
- if before_id is not None:
565
- stamp = session.execute(
566
- select(ArchivedMessage.timestamp).where(
567
- ArchivedMessage.stanza_id == before_id
568
- )
569
- ).scalar()
570
- if stamp is None:
571
- raise XMPPError(
572
- "item-not-found",
573
- f"Message {before_id} not found",
574
- )
575
- q = q.where(ArchivedMessage.timestamp < stamp)
576
- if after_id is not None:
577
- stamp = session.execute(
578
- select(ArchivedMessage.timestamp).where(
579
- ArchivedMessage.stanza_id == after_id
580
- )
581
- ).scalar()
582
- if stamp is None:
583
- raise XMPPError(
584
- "item-not-found",
585
- f"Message {after_id} not found",
586
- )
587
- q = q.where(ArchivedMessage.timestamp > stamp)
588
- if ids:
589
- q = q.filter(ArchivedMessage.stanza_id.in_(ids))
590
- if sender is not None:
591
- q = q.where(ArchivedMessage.author_jid == sender)
592
- if flip:
593
- q = q.order_by(ArchivedMessage.timestamp.desc())
594
- else:
595
- q = q.order_by(ArchivedMessage.timestamp.asc())
596
- msgs = list(session.execute(q).scalars())
597
- if ids and len(msgs) != len(ids):
319
+ q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)
320
+ if start_date is not None:
321
+ q = q.where(ArchivedMessage.timestamp >= start_date)
322
+ if end_date is not None:
323
+ q = q.where(ArchivedMessage.timestamp <= end_date)
324
+ if before_id is not None:
325
+ stamp = session.execute(
326
+ select(ArchivedMessage.timestamp).where(
327
+ ArchivedMessage.stanza_id == before_id
328
+ )
329
+ ).scalar()
330
+ if stamp is None:
598
331
  raise XMPPError(
599
332
  "item-not-found",
600
- "One of the requested messages IDs could not be found "
601
- "with the given constraints.",
333
+ f"Message {before_id} not found",
602
334
  )
603
- if last_page_n is not None:
604
- if flip:
605
- msgs = msgs[:last_page_n]
606
- else:
607
- msgs = msgs[-last_page_n:]
608
- for h in msgs:
609
- yield HistoryMessage(
610
- stanza=str(h.stanza), when=h.timestamp.replace(tzinfo=timezone.utc)
335
+ q = q.where(ArchivedMessage.timestamp < stamp)
336
+ if after_id is not None:
337
+ stamp = session.execute(
338
+ select(ArchivedMessage.timestamp).where(
339
+ ArchivedMessage.stanza_id == after_id
611
340
  )
612
-
613
- def get_first(self, room_pk: int, with_legacy_id=False) -> ArchivedMessage | None:
614
- with self.session() as session:
615
- q = (
616
- select(ArchivedMessage)
617
- .where(ArchivedMessage.room_id == room_pk)
618
- .order_by(ArchivedMessage.timestamp.asc())
341
+ ).scalar()
342
+ if stamp is None:
343
+ raise XMPPError(
344
+ "item-not-found",
345
+ f"Message {after_id} not found",
346
+ )
347
+ q = q.where(ArchivedMessage.timestamp > stamp)
348
+ if ids:
349
+ q = q.filter(ArchivedMessage.stanza_id.in_(ids))
350
+ if sender is not None:
351
+ q = q.where(ArchivedMessage.author_jid == sender)
352
+ if flip:
353
+ q = q.order_by(ArchivedMessage.timestamp.desc())
354
+ else:
355
+ q = q.order_by(ArchivedMessage.timestamp.asc())
356
+ msgs = list(session.execute(q).scalars())
357
+ if ids and len(msgs) != len(ids):
358
+ raise XMPPError(
359
+ "item-not-found",
360
+ "One of the requested messages IDs could not be found "
361
+ "with the given constraints.",
362
+ )
363
+ if last_page_n is not None:
364
+ if flip:
365
+ msgs = msgs[:last_page_n]
366
+ else:
367
+ msgs = msgs[-last_page_n:]
368
+ for h in msgs:
369
+ yield HistoryMessage(
370
+ stanza=str(h.stanza), when=h.timestamp.replace(tzinfo=timezone.utc)
619
371
  )
620
- if with_legacy_id:
621
- q = q.filter(ArchivedMessage.legacy_id.isnot(None))
622
- return session.execute(q).scalar()
623
372
 
373
+ @staticmethod
374
+ def get_first(
375
+ session: Session, room_pk: int, with_legacy_id: bool = False
376
+ ) -> Optional[ArchivedMessage]:
377
+ q = (
378
+ select(ArchivedMessage)
379
+ .where(ArchivedMessage.room_id == room_pk)
380
+ .order_by(ArchivedMessage.timestamp.asc())
381
+ )
382
+ if with_legacy_id:
383
+ q = q.filter(ArchivedMessage.legacy_id.isnot(None))
384
+ return session.execute(q).scalar()
385
+
386
+ @staticmethod
624
387
  def get_last(
625
- self, room_pk: int, source: ArchivedMessageSource | None = None
626
- ) -> ArchivedMessage | None:
627
- with self.session() as session:
628
- q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)
388
+ session: Session, room_pk: int, source: Optional[ArchivedMessageSource] = None
389
+ ) -> Optional[ArchivedMessage]:
390
+ q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)
629
391
 
630
- if source is not None:
631
- q = q.where(ArchivedMessage.source == source)
392
+ if source is not None:
393
+ q = q.where(ArchivedMessage.source == source)
632
394
 
633
- return session.execute(
634
- q.order_by(ArchivedMessage.timestamp.desc())
635
- ).scalar()
395
+ return session.execute(q.order_by(ArchivedMessage.timestamp.desc())).scalar()
636
396
 
637
- def get_first_and_last(self, room_pk: int) -> list[MamMetadata]:
397
+ def get_first_and_last(self, session: Session, room_pk: int) -> list[MamMetadata]:
638
398
  r = []
639
- with self.session():
640
- first = self.get_first(room_pk)
641
- if first is not None:
642
- r.append(MamMetadata(first.stanza_id, first.timestamp))
643
- last = self.get_last(room_pk)
644
- if last is not None:
645
- r.append(MamMetadata(last.stanza_id, last.timestamp))
399
+ first = self.get_first(session, room_pk)
400
+ if first is not None:
401
+ r.append(MamMetadata(first.stanza_id, first.timestamp))
402
+ last = self.get_last(session, room_pk)
403
+ if last is not None:
404
+ r.append(MamMetadata(last.stanza_id, last.timestamp))
646
405
  return r
647
406
 
407
+ @staticmethod
648
408
  def get_most_recent_with_legacy_id(
649
- self, room_pk: int, source: ArchivedMessageSource | None = None
650
- ) -> ArchivedMessage | None:
651
- with self.session() as session:
652
- q = (
653
- select(ArchivedMessage)
654
- .where(ArchivedMessage.room_id == room_pk)
655
- .where(ArchivedMessage.legacy_id.isnot(None))
656
- )
657
- if source is not None:
658
- q = q.where(ArchivedMessage.source == source)
659
- return session.execute(
660
- q.order_by(ArchivedMessage.timestamp.desc())
661
- ).scalar()
409
+ session: Session, room_pk: int, source: Optional[ArchivedMessageSource] = None
410
+ ) -> Optional[ArchivedMessage]:
411
+ q = (
412
+ select(ArchivedMessage)
413
+ .where(ArchivedMessage.room_id == room_pk)
414
+ .where(ArchivedMessage.legacy_id.isnot(None))
415
+ )
416
+ if source is not None:
417
+ q = q.where(ArchivedMessage.source == source)
418
+ return session.execute(q.order_by(ArchivedMessage.timestamp.desc())).scalar()
662
419
 
420
+ @staticmethod
663
421
  def get_least_recent_with_legacy_id_after(
664
- self, room_pk: int, after_id: str, source=ArchivedMessageSource.LIVE
665
- ) -> ArchivedMessage | None:
666
- with self.session() as session:
667
- after_timestamp = (
668
- session.query(ArchivedMessage.timestamp)
669
- .filter(ArchivedMessage.room_id == room_pk)
670
- .filter(ArchivedMessage.legacy_id == after_id)
671
- .scalar()
672
- )
673
- q = (
674
- select(ArchivedMessage)
675
- .where(ArchivedMessage.room_id == room_pk)
676
- .where(ArchivedMessage.legacy_id.isnot(None))
677
- .where(ArchivedMessage.source == source)
678
- .where(ArchivedMessage.timestamp > after_timestamp)
679
- )
680
- return session.execute(q.order_by(ArchivedMessage.timestamp.asc())).scalar()
681
-
682
- def get_by_legacy_id(self, room_pk: int, legacy_id: str) -> ArchivedMessage | None:
683
- with self.session() as session:
684
- return (
685
- session.query(ArchivedMessage)
686
- .filter(ArchivedMessage.room_id == room_pk)
687
- .filter(ArchivedMessage.legacy_id == legacy_id)
688
- .first()
689
- )
690
-
691
-
692
- class MultiStore(EngineMixin):
693
- def get_xmpp_ids(self, user_pk: int, xmpp_id: str) -> list[str]:
694
- with self.session() as session:
695
- multi = session.execute(
696
- select(XmppIdsMulti)
697
- .where(XmppIdsMulti.xmpp_id == xmpp_id)
698
- .where(XmppIdsMulti.user_account_id == user_pk)
699
- ).scalar()
700
- if multi is None:
701
- return []
702
- if multi.legacy_ids_multi is None:
703
- return []
704
- return [m.xmpp_id for m in multi.legacy_ids_multi.xmpp_ids]
705
-
706
- def set_xmpp_ids(
707
- self, user_pk: int, legacy_msg_id: str, xmpp_ids: list[str], fail=False
708
- ) -> None:
709
- with self.session() as session:
710
- existing = session.execute(
711
- select(LegacyIdsMulti)
712
- .where(LegacyIdsMulti.user_account_id == user_pk)
713
- .where(LegacyIdsMulti.legacy_id == legacy_msg_id)
714
- ).scalar()
715
- if existing is not None:
716
- if fail:
717
- raise
718
- log.debug("Resetting multi for %s", legacy_msg_id)
719
- session.execute(
720
- delete(LegacyIdsMulti)
721
- .where(LegacyIdsMulti.user_account_id == user_pk)
722
- .where(LegacyIdsMulti.legacy_id == legacy_msg_id)
723
- )
724
- for i in xmpp_ids:
725
- session.execute(
726
- delete(XmppIdsMulti)
727
- .where(XmppIdsMulti.user_account_id == user_pk)
728
- .where(XmppIdsMulti.xmpp_id == i)
729
- )
730
- session.commit()
731
- self.set_xmpp_ids(user_pk, legacy_msg_id, xmpp_ids, True)
732
- return
733
-
734
- row = LegacyIdsMulti(
735
- user_account_id=user_pk,
736
- legacy_id=legacy_msg_id,
737
- xmpp_ids=[
738
- XmppIdsMulti(user_account_id=user_pk, xmpp_id=i)
739
- for i in xmpp_ids
740
- if i
741
- ],
742
- )
743
- session.add(row)
744
- session.commit()
745
-
746
- def get_legacy_id(self, user_pk: int, xmpp_id: str) -> Optional[str]:
747
- with self.session() as session:
748
- multi = session.execute(
749
- select(XmppIdsMulti)
750
- .where(XmppIdsMulti.xmpp_id == xmpp_id)
751
- .where(XmppIdsMulti.user_account_id == user_pk)
752
- ).scalar()
753
- if multi is None:
754
- return None
755
- if multi.legacy_ids_multi is None:
756
- return None
757
- return multi.legacy_ids_multi.legacy_id
758
-
759
-
760
- class AttachmentStore(EngineMixin):
761
- def get_url(self, legacy_file_id: str) -> Optional[str]:
762
- with self.session() as session:
763
- return session.execute(
764
- select(Attachment.url).where(
765
- Attachment.legacy_file_id == legacy_file_id
766
- )
767
- ).scalar()
768
-
769
- def set_url(self, user_pk: int, legacy_file_id: str, url: str) -> None:
770
- with self.session() as session:
771
- att = session.execute(
772
- select(Attachment)
773
- .where(Attachment.legacy_file_id == legacy_file_id)
774
- .where(Attachment.user_account_id == user_pk)
775
- ).scalar()
776
- if att is None:
777
- att = Attachment(
778
- legacy_file_id=legacy_file_id, url=url, user_account_id=user_pk
779
- )
780
- session.add(att)
781
- else:
782
- att.url = url
783
- session.commit()
784
-
785
- def get_sims(self, url: str) -> Optional[str]:
786
- with self.session() as session:
787
- return session.execute(
788
- select(Attachment.sims).where(Attachment.url == url)
789
- ).scalar()
790
-
791
- def set_sims(self, url: str, sims: str) -> None:
792
- with self.session() as session:
793
- session.execute(
794
- update(Attachment).where(Attachment.url == url).values(sims=sims)
795
- )
796
- session.commit()
797
-
798
- def get_sfs(self, url: str) -> Optional[str]:
799
- with self.session() as session:
800
- return session.execute(
801
- select(Attachment.sfs).where(Attachment.url == url)
802
- ).scalar()
803
-
804
- def set_sfs(self, url: str, sfs: str) -> None:
805
- with self.session() as session:
806
- session.execute(
807
- update(Attachment).where(Attachment.url == url).values(sfs=sfs)
808
- )
809
- session.commit()
422
+ session: Session,
423
+ room_pk: int,
424
+ after_id: str,
425
+ source: ArchivedMessageSource = ArchivedMessageSource.LIVE,
426
+ ) -> Optional[ArchivedMessage]:
427
+ after_timestamp = (
428
+ session.query(ArchivedMessage.timestamp)
429
+ .filter(ArchivedMessage.room_id == room_pk)
430
+ .filter(ArchivedMessage.legacy_id == after_id)
431
+ .scalar()
432
+ )
433
+ q = (
434
+ select(ArchivedMessage)
435
+ .where(ArchivedMessage.room_id == room_pk)
436
+ .where(ArchivedMessage.legacy_id.isnot(None))
437
+ .where(ArchivedMessage.source == source)
438
+ .where(ArchivedMessage.timestamp > after_timestamp)
439
+ )
440
+ return session.execute(q.order_by(ArchivedMessage.timestamp.asc())).scalar()
810
441
 
811
- def remove(self, legacy_file_id: str) -> None:
812
- with self.session() as session:
813
- session.execute(
814
- delete(Attachment).where(Attachment.legacy_file_id == legacy_file_id)
815
- )
816
- session.commit()
442
+ @staticmethod
443
+ def get_by_legacy_id(
444
+ session: Session, room_pk: int, legacy_id: str
445
+ ) -> Optional[ArchivedMessage]:
446
+ return (
447
+ session.query(ArchivedMessage)
448
+ .filter(ArchivedMessage.room_id == room_pk)
449
+ .filter(ArchivedMessage.legacy_id == legacy_id)
450
+ .first()
451
+ )
817
452
 
818
453
 
819
454
  class RoomStore(UpdatedMixin):
820
455
  model = Room
821
456
 
822
- def __init__(self, *a, **kw):
823
- super().__init__(*a, **kw)
824
- with self.session() as session:
825
- session.execute(
826
- update(Room).values(
827
- subject_setter=None,
828
- user_resources=None,
829
- history_filled=False,
830
- participants_filled=False,
831
- )
832
- )
833
- session.commit()
834
-
835
- def set_avatar(
836
- self, room_pk: int, avatar_pk: int | None, avatar_legacy_id: str | None
837
- ) -> None:
838
- with self.session() as session:
839
- session.execute(
840
- update(Room)
841
- .where(Room.id == room_pk)
842
- .values(avatar_id=avatar_pk, avatar_legacy_id=avatar_legacy_id)
843
- )
844
- session.commit()
845
-
846
- def get_avatar_legacy_id(self, room_pk: int) -> Optional[str]:
847
- with self.session() as session:
848
- room = session.execute(select(Room).where(Room.id == room_pk)).scalar()
849
- if room is None or room.avatar is None:
850
- return None
851
- return room.avatar_legacy_id
852
-
853
- def get_by_jid(self, user_pk: int, jid: JID) -> Optional[Room]:
854
- if jid.resource:
855
- raise TypeError
856
- with self.session() as session:
857
- return session.execute(
858
- select(Room)
859
- .where(Room.user_account_id == user_pk)
860
- .where(Room.jid == jid)
861
- ).scalar()
862
-
863
- def get_by_legacy_id(self, user_pk: int, legacy_id: str) -> Optional[Room]:
864
- with self.session() as session:
865
- return session.execute(
866
- select(Room)
867
- .where(Room.user_account_id == user_pk)
868
- .where(Room.legacy_id == legacy_id)
869
- ).scalar()
870
-
871
- def update_subject_setter(self, room_pk: int, subject_setter: str | None):
872
- with self.session() as session:
873
- session.execute(
874
- update(Room)
875
- .where(Room.id == room_pk)
876
- .values(subject_setter=subject_setter)
877
- )
878
- session.commit()
879
-
880
- def update(self, muc: "LegacyMUC") -> int:
881
- with self.session() as session:
882
- if muc.pk is None:
883
- row = Room(
884
- jid=muc.jid,
885
- legacy_id=str(muc.legacy_id),
886
- user_account_id=muc.user_pk,
887
- )
888
- else:
889
- row = session.query(Room).filter(Room.id == muc.pk).one()
890
-
891
- row.updated = True
892
- row.extra_attributes = muc.serialize_extra_attributes()
893
- row.name = muc.name
894
- row.description = muc.description
895
- row.user_resources = (
896
- None
897
- if not muc._user_resources
898
- else json.dumps(list(muc._user_resources))
899
- )
900
- row.muc_type = muc.type
901
- row.subject = muc.subject
902
- row.subject_date = muc.subject_date
903
- row.subject_setter = muc.subject_setter
904
- row.participants_filled = muc._participants_filled
905
- row.n_participants = muc._n_participants
906
- row.user_nick = muc.user_nick
907
- session.add(row)
908
- session.commit()
909
- return row.id
910
-
911
- def update_subject_date(
912
- self, room_pk: int, subject_date: Optional[datetime]
913
- ) -> None:
914
- with self.session() as session:
915
- session.execute(
916
- update(Room).where(Room.id == room_pk).values(subject_date=subject_date)
917
- )
918
- session.commit()
919
-
920
- def update_subject(self, room_pk: int, subject: Optional[str]) -> None:
921
- with self.session() as session:
922
- session.execute(
923
- update(Room).where(Room.id == room_pk).values(subject=subject)
924
- )
925
- session.commit()
926
-
927
- def update_description(self, room_pk: int, desc: Optional[str]) -> None:
928
- with self.session() as session:
929
- session.execute(
930
- update(Room).where(Room.id == room_pk).values(description=desc)
931
- )
932
- session.commit()
933
-
934
- def update_name(self, room_pk: int, name: Optional[str]) -> None:
935
- with self.session() as session:
936
- session.execute(update(Room).where(Room.id == room_pk).values(name=name))
937
- session.commit()
938
-
939
- def update_n_participants(self, room_pk: int, n: Optional[int]) -> None:
940
- with self.session() as session:
941
- session.execute(
942
- update(Room).where(Room.id == room_pk).values(n_participants=n)
457
+ def __init__(self, session: Session) -> None:
458
+ super().__init__(session)
459
+ session.execute(
460
+ update(Room).values(
461
+ subject_setter=None,
462
+ user_resources=None,
463
+ history_filled=False,
464
+ participants_filled=False,
943
465
  )
944
- session.commit()
945
-
946
- def update_user_nick(self, room_pk, nick: str) -> None:
947
- with self.session() as session:
948
- session.execute(
949
- update(Room).where(Room.id == room_pk).values(user_nick=nick)
950
- )
951
- session.commit()
952
-
953
- def delete(self, room_pk: int) -> None:
954
- with self.session() as session:
955
- session.execute(delete(Room).where(Room.id == room_pk))
956
- session.execute(delete(Participant).where(Participant.room_id == room_pk))
957
- session.commit()
958
-
959
- def set_resource(self, room_pk: int, resources: set[str]) -> None:
960
- with self.session() as session:
961
- session.execute(
962
- update(Room)
963
- .where(Room.id == room_pk)
964
- .values(
965
- user_resources=(
966
- None if not resources else json.dumps(list(resources))
967
- )
968
- )
969
- )
970
- session.commit()
971
-
972
- def nickname_is_available(self, room_pk: int, nickname: str) -> bool:
973
- with self.session() as session:
974
- return (
975
- session.execute(
976
- select(Participant)
977
- .where(Participant.room_id == room_pk)
978
- .where(Participant.nickname == nickname)
979
- ).scalar()
980
- is None
981
- )
982
-
983
- def set_participants_filled(self, room_pk: int, val=True) -> None:
984
- with self.session() as session:
985
- session.execute(
986
- update(Room).where(Room.id == room_pk).values(participants_filled=val)
987
- )
988
- session.commit()
989
-
990
- def set_history_filled(self, room_pk: int, val=True) -> None:
991
- with self.session() as session:
992
- session.execute(
993
- update(Room).where(Room.id == room_pk).values(history_filled=True)
994
- )
995
- session.commit()
996
-
997
- def get_all(self, user_pk: int) -> Iterator[Room]:
998
- with self.session() as session:
999
- yield from session.execute(
1000
- select(Room).where(Room.user_account_id == user_pk)
1001
- ).scalars()
1002
-
1003
- def get_all_jid_and_names(self, user_pk: int) -> Iterator[Room]:
1004
- with self.session() as session:
1005
- yield from session.scalars(
1006
- select(Room)
1007
- .filter(Room.user_account_id == user_pk)
1008
- .options(load_only(Room.jid, Room.name))
1009
- .order_by(Room.name)
1010
- ).all()
1011
-
1012
-
1013
- class ParticipantStore(EngineMixin):
1014
- def __init__(self, *a, **kw):
1015
- super().__init__(*a, **kw)
1016
- with self.session() as session:
1017
- session.execute(delete(participant_hats))
1018
- session.execute(delete(Hat))
1019
- session.execute(delete(Participant))
1020
- session.commit()
1021
-
1022
- def add(self, room_pk: int, nickname: str) -> int:
1023
- with self.session() as session:
1024
- existing = session.execute(
1025
- select(Participant.id)
1026
- .where(Participant.room_id == room_pk)
1027
- .where(Participant.nickname == nickname)
1028
- ).scalar()
1029
- if existing is not None:
1030
- return existing
1031
- participant = Participant(room_id=room_pk, nickname=nickname)
1032
- session.add(participant)
1033
- session.commit()
1034
- return participant.id
1035
-
1036
- def get_by_nickname(self, room_pk: int, nickname: str) -> Optional[Participant]:
1037
- with self.session() as session:
1038
- return session.execute(
1039
- select(Participant)
1040
- .where(Participant.room_id == room_pk)
1041
- .where(Participant.nickname == nickname)
1042
- ).scalar()
1043
-
1044
- def get_by_resource(self, room_pk: int, resource: str) -> Optional[Participant]:
1045
- with self.session() as session:
1046
- return session.execute(
1047
- select(Participant)
1048
- .where(Participant.room_id == room_pk)
1049
- .where(Participant.resource == resource)
1050
- ).scalar()
1051
-
1052
- def get_by_contact(self, room_pk: int, contact_pk: int) -> Optional[Participant]:
1053
- with self.session() as session:
1054
- return session.execute(
1055
- select(Participant)
1056
- .where(Participant.room_id == room_pk)
1057
- .where(Participant.contact_id == contact_pk)
1058
- ).scalar()
1059
-
1060
- def get_all(self, room_pk: int, user_included=True) -> Iterator[Participant]:
1061
- with self.session() as session:
1062
- q = select(Participant).where(Participant.room_id == room_pk)
1063
- if not user_included:
1064
- q = q.where(~Participant.is_user)
1065
- yield from session.execute(q).scalars()
1066
-
1067
- def get_for_contact(self, contact_pk: int) -> Iterator[Participant]:
1068
- with self.session() as session:
1069
- yield from session.execute(
1070
- select(Participant).where(Participant.contact_id == contact_pk)
1071
- ).scalars()
1072
-
1073
- def update(self, participant: "LegacyParticipant") -> None:
1074
- with self.session() as session:
1075
- session.execute(
1076
- update(Participant)
1077
- .where(Participant.id == participant.pk)
1078
- .values(
1079
- nickname=participant.nickname,
1080
- resource=participant.jid.resource,
1081
- nickname_no_illegal=participant._nickname_no_illegal,
1082
- affiliation=participant.affiliation,
1083
- role=participant.role,
1084
- presence_sent=participant._presence_sent, # type:ignore
1085
- # hats=[self.add_hat(h.uri, h.title) for h in participant._hats],
1086
- is_user=participant.is_user,
1087
- contact_id=(
1088
- None
1089
- if participant.contact is None
1090
- else participant.contact.contact_pk
1091
- ),
1092
- )
1093
- )
1094
- session.commit()
1095
-
1096
- def add_hat(self, uri: str, title: str) -> Hat:
1097
- with self.session() as session:
1098
- existing = session.execute(
1099
- select(Hat).where(Hat.uri == uri).where(Hat.title == title)
1100
- ).scalar()
1101
- if existing is not None:
1102
- return existing
1103
- hat = Hat(uri=uri, title=title)
1104
- session.add(hat)
1105
- session.commit()
1106
- return hat
466
+ )
1107
467
 
1108
- def set_presence_sent(self, participant_pk: int) -> None:
1109
- with self.session() as session:
1110
- session.execute(
1111
- update(Participant)
1112
- .where(Participant.id == participant_pk)
1113
- .values(presence_sent=True)
1114
- )
1115
- session.commit()
468
+ @staticmethod
469
+ def get_all(session: Session, user_pk: int) -> Iterator[Room]:
470
+ yield from session.scalars(select(Room).where(Room.user_account_id == user_pk))
1116
471
 
1117
- def set_affiliation(self, participant_pk: int, affiliation: MucAffiliation) -> None:
1118
- with self.session() as session:
1119
- session.execute(
1120
- update(Participant)
1121
- .where(Participant.id == participant_pk)
1122
- .values(affiliation=affiliation)
1123
- )
1124
- session.commit()
1125
472
 
1126
- def set_role(self, participant_pk: int, role: MucRole) -> None:
1127
- with self.session() as session:
1128
- session.execute(
1129
- update(Participant)
1130
- .where(Participant.id == participant_pk)
1131
- .values(role=role)
1132
- )
1133
- session.commit()
473
+ class ParticipantStore:
474
+ def __init__(self, session: Session) -> None:
475
+ session.execute(delete(Participant))
1134
476
 
1135
- def set_hats(self, participant_pk: int, hats: list[HatTuple]) -> None:
1136
- with self.session() as session:
1137
- part = session.execute(
1138
- select(Participant).where(Participant.id == participant_pk)
1139
- ).scalar()
1140
- if part is None:
1141
- raise ValueError
1142
- part.hats.clear()
1143
- for h in hats:
1144
- hat = self.add_hat(*h)
1145
- if hat in part.hats:
1146
- continue
1147
- part.hats.append(hat)
1148
- session.commit()
1149
-
1150
- def delete(self, participant_pk: int) -> None:
1151
- with self.session() as session:
1152
- session.execute(delete(Participant).where(Participant.id == participant_pk))
1153
-
1154
- def get_count(self, room_pk: int) -> int:
1155
- with self.session() as session:
1156
- return session.query(
1157
- count(Participant.id).filter(Participant.room_id == room_pk)
1158
- ).scalar()
477
+ @staticmethod
478
+ def get_all(
479
+ session, room_pk: int, user_included: bool = True
480
+ ) -> Iterator[Participant]:
481
+ query = select(Participant).where(Participant.room_id == room_pk)
482
+ if not user_included:
483
+ query = query.where(~Participant.is_user)
484
+ yield from session.scalars(query).unique()
1159
485
 
1160
486
 
1161
- class BobStore(EngineMixin):
487
+ class BobStore:
1162
488
  _ATTR_MAP = {
1163
489
  "sha-1": "sha_1",
1164
490
  "sha1": "sha_1",
@@ -1174,8 +500,7 @@ class BobStore(EngineMixin):
1174
500
  "sha_512": hashlib.sha512,
1175
501
  }
1176
502
 
1177
- def __init__(self, *a, **k):
1178
- super().__init__(*a, **k)
503
+ def __init__(self) -> None:
1179
504
  self.root_dir = config.HOME_DIR / "slidge_stickers"
1180
505
  self.root_dir.mkdir(exist_ok=True)
1181
506
 
@@ -1187,20 +512,19 @@ class BobStore(EngineMixin):
1187
512
  alg_name, digest = self.__split_cid(cid)
1188
513
  attr = self._ATTR_MAP.get(alg_name)
1189
514
  if attr is None:
1190
- log.warning("Unknown hash algo: %s", alg_name)
515
+ log.warning("Unknown hash algorithm: %s", alg_name)
1191
516
  return None
1192
517
  return getattr(Bob, attr) == digest
1193
518
 
1194
- def get(self, cid: str) -> Bob | None:
1195
- with self.session() as session:
1196
- try:
1197
- return session.query(Bob).filter(self.__get_condition(cid)).scalar()
1198
- except ValueError:
1199
- log.warning("Cannot get Bob with CID: %s", cid)
1200
- return None
519
+ def get(self, session: Session, cid: str) -> Bob | None:
520
+ try:
521
+ return session.query(Bob).filter(self.__get_condition(cid)).scalar()
522
+ except ValueError:
523
+ log.warning("Cannot get Bob with CID: %s", cid)
524
+ return None
1201
525
 
1202
- def get_sticker(self, cid: str) -> Sticker | None:
1203
- bob = self.get(cid)
526
+ def get_sticker(self, session: Session, cid: str) -> Sticker | None:
527
+ bob = self.get(session, cid)
1204
528
  if bob is None:
1205
529
  return None
1206
530
  return Sticker(
@@ -1209,8 +533,10 @@ class BobStore(EngineMixin):
1209
533
  {h: getattr(bob, h) for h in self._ALG_MAP},
1210
534
  )
1211
535
 
1212
- def get_bob(self, _jid, _node, _ifrom, cid: str) -> BitsOfBinary | None:
1213
- stored = self.get(cid)
536
+ def get_bob(
537
+ self, session: Session, _jid, _node, _ifrom, cid: str
538
+ ) -> BitsOfBinary | None:
539
+ stored = self.get(session, cid)
1214
540
  if stored is None:
1215
541
  return None
1216
542
  bob = BitsOfBinary()
@@ -1220,53 +546,45 @@ class BobStore(EngineMixin):
1220
546
  bob["cid"] = cid
1221
547
  return bob
1222
548
 
1223
- def del_bob(self, _jid, _node, _ifrom, cid: str) -> None:
1224
- with self.session() as orm:
1225
- try:
1226
- file_name = orm.scalar(
1227
- delete(Bob)
1228
- .where(self.__get_condition(cid))
1229
- .returning(Bob.file_name)
1230
- )
1231
- except ValueError:
1232
- log.warning("Cannot delete Bob with CID: %s", cid)
1233
- return None
1234
- if file_name is None:
1235
- log.warning("No BoB with CID: %s", cid)
1236
- return None
1237
- (self.root_dir / file_name).unlink()
1238
- orm.commit()
1239
-
1240
- def set_bob(self, _jid, _node, _ifrom, bob: BitsOfBinary) -> None:
549
+ def del_bob(self, session: Session, _jid, _node, _ifrom, cid: str) -> None:
550
+ try:
551
+ file_name = session.scalar(
552
+ delete(Bob).where(self.__get_condition(cid)).returning(Bob.file_name)
553
+ )
554
+ except ValueError:
555
+ log.warning("Cannot delete Bob with CID: %s", cid)
556
+ return None
557
+ if file_name is None:
558
+ log.warning("No BoB with CID: %s", cid)
559
+ return None
560
+ (self.root_dir / file_name).unlink()
561
+
562
+ def set_bob(self, session: Session, _jid, _node, _ifrom, bob: BitsOfBinary) -> None:
1241
563
  cid = bob["cid"]
1242
564
  try:
1243
565
  alg_name, digest = self.__split_cid(cid)
1244
566
  except ValueError:
1245
- log.warning("Cannot set Bob with CID: %s", cid)
567
+ log.warning("Invalid CID provided: %s", cid)
1246
568
  return
1247
569
  attr = self._ATTR_MAP.get(alg_name)
1248
570
  if attr is None:
1249
- log.warning("Cannot set BoB with unknown hash algo: %s", alg_name)
1250
- return None
1251
- with self.session() as orm:
1252
- existing = self.get(bob["cid"])
1253
- if existing is not None:
1254
- log.debug("Bob already known")
1255
- return
1256
- bytes_ = bob["data"]
1257
- path = self.root_dir / uuid.uuid4().hex
1258
- if bob["type"]:
1259
- path = path.with_suffix(guess_extension(bob["type"]) or "")
1260
- path.write_bytes(bytes_)
1261
- hashes = {k: v(bytes_).hexdigest() for k, v in self._ALG_MAP.items()}
1262
- if hashes[attr] != digest:
1263
- raise ValueError(
1264
- "The given CID does not correspond to the result of our hash"
1265
- )
1266
- row = Bob(file_name=path.name, content_type=bob["type"] or None, **hashes)
1267
- orm.add(row)
1268
- orm.commit()
571
+ log.warning("Cannot set Bob: Unknown algorithm type: %s", alg_name)
572
+ return
573
+ existing = self.get(session, bob["cid"])
574
+ if existing:
575
+ log.debug("Bob already exists")
576
+ return
577
+ bytes_ = bob["data"]
578
+ path = self.root_dir / uuid.uuid4().hex
579
+ if bob["type"]:
580
+ path = path.with_suffix(guess_extension(bob["type"]) or "")
581
+ path.write_bytes(bytes_)
582
+ hashes = {k: v(bytes_).hexdigest() for k, v in self._ALG_MAP.items()}
583
+ if hashes[attr] != digest:
584
+ path.unlink(missing_ok=True)
585
+ raise ValueError("Provided CID does not match calculated hash")
586
+ row = Bob(file_name=path.name, content_type=bob["type"] or None, **hashes)
587
+ session.add(row)
1269
588
 
1270
589
 
1271
590
  log = logging.getLogger(__name__)
1272
- _session: Optional[Session] = None