slidge 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (102)
  1. slidge/__init__.py +3 -5
  2. slidge/__main__.py +2 -197
  3. slidge/__version__.py +5 -0
  4. slidge/command/adhoc.py +40 -17
  5. slidge/command/admin.py +24 -12
  6. slidge/command/base.py +10 -8
  7. slidge/command/categories.py +13 -3
  8. slidge/command/chat_command.py +29 -2
  9. slidge/command/register.py +32 -16
  10. slidge/command/user.py +106 -13
  11. slidge/contact/contact.py +254 -50
  12. slidge/contact/roster.py +124 -53
  13. slidge/core/config.py +19 -13
  14. slidge/core/dispatcher/__init__.py +3 -0
  15. slidge/core/{gateway → dispatcher}/caps.py +12 -8
  16. slidge/core/{gateway → dispatcher}/disco.py +10 -18
  17. slidge/core/dispatcher/message/__init__.py +10 -0
  18. slidge/core/dispatcher/message/chat_state.py +40 -0
  19. slidge/core/dispatcher/message/marker.py +62 -0
  20. slidge/core/dispatcher/message/message.py +397 -0
  21. slidge/core/dispatcher/muc/__init__.py +12 -0
  22. slidge/core/dispatcher/muc/admin.py +98 -0
  23. slidge/core/{gateway → dispatcher/muc}/mam.py +25 -17
  24. slidge/core/dispatcher/muc/misc.py +121 -0
  25. slidge/core/dispatcher/muc/owner.py +96 -0
  26. slidge/core/{gateway → dispatcher/muc}/ping.py +11 -17
  27. slidge/core/dispatcher/presence.py +176 -0
  28. slidge/core/dispatcher/registration.py +85 -0
  29. slidge/core/{gateway → dispatcher}/search.py +9 -16
  30. slidge/core/dispatcher/session_dispatcher.py +84 -0
  31. slidge/core/dispatcher/util.py +174 -0
  32. slidge/core/{gateway/vcard_temp.py → dispatcher/vcard.py} +35 -19
  33. slidge/core/{gateway/base.py → gateway.py} +176 -153
  34. slidge/core/mixins/__init__.py +11 -1
  35. slidge/core/mixins/attachment.py +106 -67
  36. slidge/core/mixins/avatar.py +94 -25
  37. slidge/core/mixins/base.py +10 -4
  38. slidge/core/mixins/db.py +18 -0
  39. slidge/core/mixins/disco.py +0 -10
  40. slidge/core/mixins/lock.py +10 -8
  41. slidge/core/mixins/message.py +11 -195
  42. slidge/core/mixins/message_maker.py +17 -9
  43. slidge/core/mixins/message_text.py +211 -0
  44. slidge/core/mixins/presence.py +17 -4
  45. slidge/core/pubsub.py +114 -288
  46. slidge/core/session.py +101 -40
  47. slidge/db/__init__.py +4 -0
  48. slidge/db/alembic/__init__.py +0 -0
  49. slidge/db/alembic/env.py +64 -0
  50. slidge/db/alembic/old_user_store.py +183 -0
  51. slidge/db/alembic/script.py.mako +26 -0
  52. slidge/db/alembic/versions/09f27f098baa_add_missing_attributes_in_room.py +36 -0
  53. slidge/db/alembic/versions/15b0bd83407a_remove_bogus_unique_constraints_on_room_.py +85 -0
  54. slidge/db/alembic/versions/2461390c0af2_store_contacts_caps_verstring_in_db.py +36 -0
  55. slidge/db/alembic/versions/29f5280c61aa_store_subject_setter_in_room.py +37 -0
  56. slidge/db/alembic/versions/2b1f45ab7379_store_room_subject_setter_by_nickname.py +41 -0
  57. slidge/db/alembic/versions/3071e0fa69d4_add_contact_client_type.py +52 -0
  58. slidge/db/alembic/versions/45c24cc73c91_add_bob.py +42 -0
  59. slidge/db/alembic/versions/5bd48bfdffa2_lift_room_legacy_id_constraint.py +61 -0
  60. slidge/db/alembic/versions/82a4af84b679_add_muc_history_filled.py +48 -0
  61. slidge/db/alembic/versions/8b993243a536_add_vcard_content_to_contact_table.py +43 -0
  62. slidge/db/alembic/versions/8d2ced764698_rely_on_db_to_store_contacts_rooms_and_.py +139 -0
  63. slidge/db/alembic/versions/aa9d82a7f6ef_db_creation.py +101 -0
  64. slidge/db/alembic/versions/abba1ae0edb3_store_avatar_legacy_id_in_the_contact_.py +79 -0
  65. slidge/db/alembic/versions/b33993e87db3_move_everything_to_persistent_db.py +214 -0
  66. slidge/db/alembic/versions/b64b1a793483_add_source_and_legacy_id_for_archived_.py +52 -0
  67. slidge/db/alembic/versions/c4a8ec35a0e8_per_room_user_nick.py +34 -0
  68. slidge/db/alembic/versions/e91195719c2c_store_users_avatars_persistently.py +26 -0
  69. slidge/db/avatar.py +205 -0
  70. slidge/db/meta.py +72 -0
  71. slidge/db/models.py +405 -0
  72. slidge/db/store.py +1257 -0
  73. slidge/group/archive.py +58 -14
  74. slidge/group/bookmarks.py +89 -65
  75. slidge/group/participant.py +107 -40
  76. slidge/group/room.py +402 -213
  77. slidge/main.py +202 -0
  78. slidge/migration.py +45 -1
  79. slidge/slixfix/__init__.py +31 -1
  80. slidge/{core/gateway → slixfix}/delivery_receipt.py +1 -1
  81. slidge/slixfix/roster.py +13 -4
  82. slidge/slixfix/xep_0292/vcard4.py +1 -87
  83. slidge/util/archive_msg.py +2 -1
  84. slidge/util/db.py +4 -228
  85. slidge/util/test.py +91 -4
  86. slidge/util/types.py +39 -4
  87. slidge/util/util.py +45 -2
  88. {slidge-0.1.3.dist-info → slidge-0.2.0.dist-info}/METADATA +10 -5
  89. slidge-0.2.0.dist-info/RECORD +131 -0
  90. slidge-0.2.0.dist-info/entry_points.txt +3 -0
  91. slidge/core/cache.py +0 -183
  92. slidge/core/gateway/__init__.py +0 -3
  93. slidge/core/gateway/muc_admin.py +0 -35
  94. slidge/core/gateway/presence.py +0 -95
  95. slidge/core/gateway/registration.py +0 -53
  96. slidge/core/gateway/session_dispatcher.py +0 -804
  97. slidge/util/schema.sql +0 -126
  98. slidge/util/sql.py +0 -508
  99. slidge-0.1.3.dist-info/RECORD +0 -96
  100. slidge-0.1.3.dist-info/entry_points.txt +0 -3
  101. {slidge-0.1.3.dist-info → slidge-0.2.0.dist-info}/LICENSE +0 -0
  102. {slidge-0.1.3.dist-info → slidge-0.2.0.dist-info}/WHEEL +0 -0
@@ -1,3 +1,4 @@
1
+ import base64
1
2
  import functools
2
3
  import logging
3
4
  import os
@@ -7,22 +8,23 @@ import stat
7
8
  import tempfile
8
9
  import warnings
9
10
  from datetime import datetime
10
- from mimetypes import guess_type
11
+ from itertools import chain
12
+ from mimetypes import guess_extension, guess_type
11
13
  from pathlib import Path
12
- from typing import IO, Collection, Optional, Sequence, Union
14
+ from typing import IO, AsyncIterator, Collection, Optional, Sequence, Union
13
15
  from urllib.parse import quote as urlquote
14
16
  from uuid import uuid4
15
17
  from xml.etree import ElementTree as ET
16
18
 
17
- import blurhash
18
- from PIL import Image
19
+ import thumbhash
20
+ from PIL import Image, ImageOps
19
21
  from slixmpp import JID, Message
20
22
  from slixmpp.exceptions import IqError
21
23
  from slixmpp.plugins.xep_0363 import FileUploadError
22
- from slixmpp.plugins.xep_0385.stanza import Sims
23
24
  from slixmpp.plugins.xep_0447.stanza import StatelessFileSharing
24
25
 
25
- from ...util.sql import db
26
+ from ...db.avatar import avatar_cache
27
+ from ...slixfix.xep_0264.stanza import Thumbnail
26
28
  from ...util.types import (
27
29
  LegacyAttachment,
28
30
  LegacyMessageType,
@@ -31,13 +33,13 @@ from ...util.types import (
31
33
  )
32
34
  from ...util.util import fix_suffix
33
35
  from .. import config
34
- from ..cache import avatar_cache
35
- from .message_maker import MessageMaker
36
+ from .message_text import TextMessageMixin
36
37
 
37
38
 
38
- class AttachmentMixin(MessageMaker):
39
- def send_text(self, *_, **k) -> Optional[Message]:
40
- raise NotImplementedError
39
+ class AttachmentMixin(TextMessageMixin):
40
+ def __init__(self, *a, **kw):
41
+ super().__init__(*a, **kw)
42
+ self.__store = self.xmpp.store.attachments
41
43
 
42
44
  async def __upload(
43
45
  self,
@@ -138,6 +140,7 @@ class AttachmentMixin(MessageMaker):
138
140
  async def __get_url(
139
141
  self,
140
142
  file_path: Optional[Path] = None,
143
+ async_data_stream: Optional[AsyncIterator[bytes]] = None,
141
144
  data_stream: Optional[IO[bytes]] = None,
142
145
  data: Optional[bytes] = None,
143
146
  file_url: Optional[str] = None,
@@ -146,13 +149,13 @@ class AttachmentMixin(MessageMaker):
146
149
  legacy_file_id: Optional[Union[str, int]] = None,
147
150
  ) -> tuple[bool, Optional[Path], str]:
148
151
  if legacy_file_id:
149
- cache = db.attachment_get_url(legacy_file_id)
152
+ cache = self.__store.get_url(str(legacy_file_id))
150
153
  if cache is not None:
151
154
  async with self.session.http.head(cache) as r:
152
155
  if r.status < 400:
153
156
  return False, None, cache
154
157
  else:
155
- db.attachment_remove(legacy_file_id)
158
+ self.__store.remove(str(legacy_file_id))
156
159
 
157
160
  if file_url and config.USE_ATTACHMENT_ORIGINAL_URLS:
158
161
  return False, None, file_url
@@ -165,7 +168,12 @@ class AttachmentMixin(MessageMaker):
165
168
  )
166
169
 
167
170
  if file_path is None:
168
- file_name = str(uuid4()) if file_name is None else file_name
171
+ if file_name is None:
172
+ file_name = str(uuid4())
173
+ if content_type is not None:
174
+ ext = guess_extension(content_type, strict=False) # type:ignore
175
+ if ext is not None:
176
+ file_name += ext
169
177
  temp_dir = Path(tempfile.mkdtemp())
170
178
  file_path = temp_dir / file_name
171
179
  if file_url:
@@ -173,14 +181,23 @@ class AttachmentMixin(MessageMaker):
173
181
  with file_path.open("wb") as f:
174
182
  f.write(await r.read())
175
183
 
176
- else:
177
- if data_stream is not None:
178
- data = data_stream.read()
184
+ elif data_stream is not None:
185
+ data = data_stream.read()
179
186
  if data is None:
180
187
  raise RuntimeError
181
188
 
182
189
  with file_path.open("wb") as f:
183
190
  f.write(data)
191
+ elif async_data_stream is not None:
192
+ # TODO: patch slixmpp to allow this as data source for
193
+ # upload_file() so we don't even have to write anything
194
+ # to disk.
195
+ with file_path.open("wb") as f:
196
+ async for chunk in async_data_stream:
197
+ f.write(chunk)
198
+ elif data is not None:
199
+ with file_path.open("wb") as f:
200
+ f.write(data)
184
201
 
185
202
  is_temp = not bool(config.NO_UPLOAD_PATH)
186
203
  else:
@@ -198,7 +215,7 @@ class AttachmentMixin(MessageMaker):
198
215
  local_path = file_path
199
216
  new_url = await self.__upload(file_path, file_name, content_type)
200
217
  if legacy_file_id:
201
- db.attachment_store_url(legacy_file_id, new_url)
218
+ self.__store.set_url(self.session.user_pk, str(legacy_file_id), new_url)
202
219
 
203
220
  return is_temp, local_path, new_url
204
221
 
@@ -210,37 +227,44 @@ class AttachmentMixin(MessageMaker):
210
227
  content_type: Optional[str] = None,
211
228
  caption: Optional[str] = None,
212
229
  file_name: Optional[str] = None,
213
- ):
214
- cache = db.attachment_get_sims(uploaded_url)
230
+ ) -> Thumbnail | None:
231
+ cache = self.__store.get_sims(uploaded_url)
215
232
  if cache:
216
- msg.append(Sims(xml=ET.fromstring(cache)))
217
- return
233
+ ref = self.xmpp["xep_0372"].stanza.Reference(xml=ET.fromstring(cache))
234
+ msg.append(ref)
235
+ if ref["sims"]["file"].get_plugin("thumbnail", check=True):
236
+ return ref["sims"]["file"]["thumbnail"]
237
+ else:
238
+ return None
218
239
 
219
240
  if not path:
220
- return
241
+ return None
221
242
 
222
- sims = self.xmpp["xep_0385"].get_sims(
243
+ ref = self.xmpp["xep_0385"].get_sims(
223
244
  path, [uploaded_url], content_type, caption
224
245
  )
225
246
  if file_name:
226
- sims["sims"]["file"]["name"] = file_name
247
+ ref["sims"]["file"]["name"] = file_name
248
+ thumbnail = None
227
249
  if content_type is not None and content_type.startswith("image"):
228
250
  try:
229
251
  h, x, y = await self.xmpp.loop.run_in_executor(
230
- avatar_cache._thread_pool, get_blurhash, path
252
+ avatar_cache._thread_pool, get_thumbhash, path
231
253
  )
232
254
  except Exception as e:
233
- log.debug("Could not generate a blurhash", exc_info=e)
255
+ log.debug("Could not generate a thumbhash", exc_info=e)
234
256
  else:
235
- thumbnail = sims["sims"]["file"]["thumbnail"]
257
+ thumbnail = ref["sims"]["file"]["thumbnail"]
236
258
  thumbnail["width"] = x
237
259
  thumbnail["height"] = y
238
- thumbnail["media-type"] = "image/blurhash"
239
- thumbnail["uri"] = "data:image/blurhash," + urlquote(h)
260
+ thumbnail["media-type"] = "image/thumbhash"
261
+ thumbnail["uri"] = "data:image/thumbhash;base64," + urlquote(h)
240
262
 
241
- db.attachment_store_sims(uploaded_url, str(sims))
263
+ self.__store.set_sims(uploaded_url, str(ref))
242
264
 
243
- msg.append(sims)
265
+ msg.append(ref)
266
+
267
+ return thumbnail
244
268
 
245
269
  def __set_sfs(
246
270
  self,
@@ -250,8 +274,9 @@ class AttachmentMixin(MessageMaker):
250
274
  content_type: Optional[str] = None,
251
275
  caption: Optional[str] = None,
252
276
  file_name: Optional[str] = None,
277
+ thumbnail: Optional[Thumbnail] = None,
253
278
  ):
254
- cache = db.attachment_get_sfs(uploaded_url)
279
+ cache = self.__store.get_sfs(uploaded_url)
255
280
  if cache:
256
281
  msg.append(StatelessFileSharing(xml=ET.fromstring(cache)))
257
282
  return
@@ -262,7 +287,9 @@ class AttachmentMixin(MessageMaker):
262
287
  sfs = self.xmpp["xep_0447"].get_sfs(path, [uploaded_url], content_type, caption)
263
288
  if file_name:
264
289
  sfs["file"]["name"] = file_name
265
- db.attachment_store_sfs(uploaded_url, str(sfs))
290
+ if thumbnail is not None:
291
+ sfs["file"].append(thumbnail)
292
+ self.__store.set_sfs(uploaded_url, str(sfs))
266
293
 
267
294
  msg.append(sfs)
268
295
 
@@ -274,6 +301,7 @@ class AttachmentMixin(MessageMaker):
274
301
  caption: Optional[str] = None,
275
302
  carbon=False,
276
303
  when: Optional[datetime] = None,
304
+ correction=False,
277
305
  **kwargs,
278
306
  ) -> list[Message]:
279
307
  msg["oob"]["url"] = uploaded_url
@@ -281,11 +309,19 @@ class AttachmentMixin(MessageMaker):
281
309
  if caption:
282
310
  m1 = self._send(msg, carbon=carbon, **kwargs)
283
311
  m2 = self.send_text(
284
- caption, legacy_msg_id=legacy_msg_id, when=when, carbon=carbon, **kwargs
312
+ caption,
313
+ legacy_msg_id=legacy_msg_id,
314
+ when=when,
315
+ carbon=carbon,
316
+ correction=correction,
317
+ **kwargs,
285
318
  )
286
319
  return [m1, m2] if m2 else [m1]
287
320
  else:
288
- self._set_msg_id(msg, legacy_msg_id)
321
+ if correction:
322
+ msg["replace"]["id"] = self._replace_id(legacy_msg_id)
323
+ else:
324
+ self._set_msg_id(msg, legacy_msg_id)
289
325
  return [self._send(msg, carbon=carbon, **kwargs)]
290
326
 
291
327
  async def send_file(
@@ -293,6 +329,7 @@ class AttachmentMixin(MessageMaker):
293
329
  file_path: Optional[Union[Path, str]] = None,
294
330
  legacy_msg_id: Optional[LegacyMessageType] = None,
295
331
  *,
332
+ async_data_stream: Optional[AsyncIterator[bytes]] = None,
296
333
  data_stream: Optional[IO[bytes]] = None,
297
334
  data: Optional[bytes] = None,
298
335
  file_url: Optional[str] = None,
@@ -309,6 +346,7 @@ class AttachmentMixin(MessageMaker):
309
346
  Send a single file from this :term:`XMPP Entity`.
310
347
 
311
348
  :param file_path: Path to the attachment
349
+ :param async_data_stream: Alternatively (and ideally) an AsyncIterator yielding bytes
312
350
  :param data_stream: Alternatively, a stream of bytes (such as a File object)
313
351
  :param data: Alternatively, a bytes object
314
352
  :param file_url: Alternatively, a URL
@@ -326,6 +364,16 @@ class AttachmentMixin(MessageMaker):
326
364
  carbon = kwargs.pop("carbon", False)
327
365
  mto = kwargs.pop("mto", None)
328
366
  store_multi = kwargs.pop("store_multi", True)
367
+ correction = kwargs.get("correction", False)
368
+ if correction and (original_xmpp_id := self._legacy_to_xmpp(legacy_msg_id)):
369
+ xmpp_ids = self.xmpp.store.multi.get_xmpp_ids(
370
+ self.session.user_pk, original_xmpp_id
371
+ )
372
+
373
+ for xmpp_id in xmpp_ids:
374
+ if xmpp_id == original_xmpp_id:
375
+ continue
376
+ self.retract(xmpp_id, thread)
329
377
  msg = self._make_message(
330
378
  when=when,
331
379
  reply_to=reply_to,
@@ -339,6 +387,7 @@ class AttachmentMixin(MessageMaker):
339
387
 
340
388
  is_temp, local_path, new_url = await self.__get_url(
341
389
  Path(file_path) if file_path else None,
390
+ async_data_stream,
342
391
  data_stream,
343
392
  data,
344
393
  file_url,
@@ -355,10 +404,12 @@ class AttachmentMixin(MessageMaker):
355
404
  self._set_msg_id(msg, legacy_msg_id)
356
405
  return None, [self._send(msg, **kwargs)]
357
406
 
358
- await self.__set_sims(
407
+ thumbnail = await self.__set_sims(
359
408
  msg, new_url, local_path, content_type, caption, file_name
360
409
  )
361
- self.__set_sfs(msg, new_url, local_path, content_type, caption, file_name)
410
+ self.__set_sfs(
411
+ msg, new_url, local_path, content_type, caption, file_name, thumbnail
412
+ )
362
413
  if is_temp and isinstance(local_path, Path):
363
414
  local_path.unlink()
364
415
  local_path.parent.rmdir()
@@ -472,35 +523,23 @@ class AttachmentMixin(MessageMaker):
472
523
  ids.append(stanza_id["id"])
473
524
  else:
474
525
  ids.append(msg.get_id())
475
- db.attachment_store_legacy_to_multi_xmpp_msg_ids(legacy_msg_id, ids)
476
-
477
-
478
- def get_blurhash(path: Path, n=9) -> tuple[str, int, int]:
479
- img = Image.open(path)
480
- width, height = img.size
481
- n = min(width, height, n)
482
- if width == height:
483
- x = y = n
484
- elif width > height:
485
- x = n
486
- y = round(n * height / width)
487
- else:
488
- x = round(n * width / height)
489
- y = n
490
- # There are 2 blurhash-python packages:
491
- # https://github.com/woltapp/blurhash-python
492
- # https://github.com/halcy/blurhash-python
493
- # With this hack we're compatible with both, which is useful for packaging
494
- # without using pyproject.toml, as most distro do
495
- try:
496
- hash_ = blurhash.encode(img, x, y)
497
- except TypeError:
498
- # We are using halcy's blurhash which expects
499
- # the 1st argument to be a 3-dimensional array
500
- import numpy # type:ignore
501
-
502
- hash_ = blurhash.encode(numpy.array(img.convert("RGB")), x, y)
503
- return hash_, width, height
526
+ self.xmpp.store.multi.set_xmpp_ids(
527
+ self.session.user_pk, str(legacy_msg_id), ids
528
+ )
529
+
530
+
531
def get_thumbhash(path: Path) -> tuple[str, int, int]:
    """
    Compute a ThumbHash placeholder for the image stored at *path*.

    The image is converted to RGBA and, if larger than 100x100 pixels,
    downscaled (aspect ratio preserved) to fit in 100x100 before hashing.
    ThumbHash encoders expect inputs of at most 100x100 pixels. The EXIF
    orientation is applied so the hash matches what a viewer displays.

    :param path: Path to an image file readable by Pillow.
    :return: ``(hash, width, height)``. ``hash`` is the base64-encoded
        ThumbHash. ``width``/``height`` are the full-size dimensions of the
        image *as displayed*, i.e. after honouring the EXIF orientation.
    """
    with path.open("rb") as fp:
        img = Image.open(fp)
        width, height = img.size
        # EXIF tag 0x0112 is "Orientation". Values 5-8 involve a 90/270
        # degree rotation: the displayed image has width and height swapped
        # relative to the stored pixels. The hash below is computed on the
        # transposed image, so the reported dimensions must agree with it.
        if img.getexif().get(0x0112) in (5, 6, 7, 8):
            width, height = height, width
        img = img.convert("RGBA")
        if img.width > 100 or img.height > 100:
            img.thumbnail((100, 100))
        # Transpose the (small) thumbnail rather than the full-size image.
        img = ImageOps.exif_transpose(img)
        # Flatten the per-pixel (R, G, B, A) tuples into one flat list.
        rgba = list(chain.from_iterable(img.getdata()))
    ints = thumbhash.rgba_to_thumb_hash(img.width, img.height, rgba)
    return base64.b64encode(bytes(ints)).decode(), width, height
504
543
 
505
544
 
506
545
  log = logging.getLogger(__name__)
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Optional
5
5
 
6
6
  from slixmpp import JID
7
7
 
8
+ from ...db.avatar import CachedAvatar, avatar_cache
8
9
  from ...util.types import (
9
10
  URL,
10
11
  AnyBaseSession,
@@ -12,7 +13,6 @@ from ...util.types import (
12
13
  AvatarType,
13
14
  LegacyFileIdType,
14
15
  )
15
- from ..cache import avatar_cache
16
16
 
17
17
  if TYPE_CHECKING:
18
18
  from ..pubsub import PepAvatar
@@ -28,13 +28,14 @@ class AvatarMixin:
28
28
 
29
29
  jid: JID = NotImplemented
30
30
  session: AnyBaseSession = NotImplemented
31
- _avatar_pubsub_broadcast: bool = NotImplemented
32
31
  _avatar_bare_jid: bool = NotImplemented
33
32
 
34
33
  def __init__(self) -> None:
35
34
  super().__init__()
36
35
  self._set_avatar_task: Optional[Task] = None
36
+ self.__broadcast_task: Optional[Task] = None
37
37
  self.__avatar_unique_id: Optional[AvatarIdType] = None
38
+ self._avatar_pk: Optional[int] = None
38
39
 
39
40
  @property
40
41
  def __avatar_jid(self):
@@ -72,6 +73,10 @@ class AvatarMixin:
72
73
  name=f"Set avatar of {self} from property",
73
74
  )
74
75
 
76
+ @property
77
+ def avatar_pk(self) -> int | None:
78
+ return self._avatar_pk
79
+
75
80
  @staticmethod
76
81
  def __get_uid(a: Optional[AvatarType]) -> Optional[AvatarIdType]:
77
82
  if isinstance(a, str):
@@ -84,17 +89,39 @@ class AvatarMixin:
84
89
  return None
85
90
  raise TypeError("Bad avatar", a)
86
91
 
87
- async def __set_avatar(self, a: Optional[AvatarType], uid: Optional[AvatarIdType]):
92
+ async def __set_avatar(
93
+ self, a: Optional[AvatarType], uid: Optional[AvatarIdType], delete: bool
94
+ ):
88
95
  self.__avatar_unique_id = uid
89
- await self.session.xmpp.pubsub.set_avatar(
90
- jid=self.__avatar_jid,
91
- avatar=a,
92
- unique_id=None if isinstance(uid, URL) else uid,
93
- broadcast_to=self.session.user.jid.bare,
94
- broadcast=self._avatar_pubsub_broadcast,
95
- )
96
+
97
+ if a is None:
98
+ cached_avatar = None
99
+ self._avatar_pk = None
100
+ else:
101
+ try:
102
+ cached_avatar = await avatar_cache.convert_or_get(a)
103
+ except Exception as e:
104
+ self.session.log.error("Failed to set avatar %s", a, exc_info=e)
105
+ self._avatar_pk = None
106
+ self.__avatar_unique_id = uid
107
+ return
108
+ self._avatar_pk = cached_avatar.pk
109
+
110
+ if self.__should_pubsub_broadcast():
111
+ await self.session.xmpp.pubsub.broadcast_avatar(
112
+ self.__avatar_jid, self.session.user_jid, cached_avatar
113
+ )
114
+
115
+ if delete and isinstance(a, Path):
116
+ a.unlink()
117
+
96
118
  self._post_avatar_update()
97
119
 
120
+ def __should_pubsub_broadcast(self):
121
+ return getattr(self, "is_friend", False) and getattr(
122
+ self, "added_to_roster", False
123
+ )
124
+
98
125
  async def _no_change(self, a: Optional[AvatarType], uid: Optional[AvatarIdType]):
99
126
  if a is None:
100
127
  return self.__avatar_unique_id is None
@@ -103,23 +130,27 @@ class AvatarMixin:
103
130
  if isinstance(uid, URL):
104
131
  if self.__avatar_unique_id != uid:
105
132
  return False
106
- return not await avatar_cache.url_has_changed(uid)
133
+ return not await avatar_cache.url_modified(uid)
107
134
  return self.__avatar_unique_id == uid
108
135
 
109
136
  async def set_avatar(
110
137
  self,
111
138
  a: Optional[AvatarType],
112
139
  avatar_unique_id: Optional[LegacyFileIdType] = None,
140
+ delete: bool = False,
113
141
  blocking=False,
114
142
  cancel=True,
115
143
  ) -> None:
116
144
  """
117
145
  Set an avatar for this entity
118
146
 
119
- :param a:
120
- :param avatar_unique_id:
121
- :param blocking:
122
- :param cancel:
147
+ :param a: The avatar, in one of the types slidge supports
148
+ :param avatar_unique_id: A globally unique ID for the avatar on the
149
+ legacy network
150
+ :param delete: If the avatar is provided as a Path, whether to delete
151
+ it once used or not.
152
+ :param blocking: Internal use by slidge for tests, do not use!
153
+ :param cancel: Internal use by slidge, do not use!
123
154
  """
124
155
  if avatar_unique_id is None and a is not None:
125
156
  avatar_unique_id = self.__get_uid(a)
@@ -128,7 +159,7 @@ class AvatarMixin:
128
159
  if cancel and self._set_avatar_task:
129
160
  self._set_avatar_task.cancel()
130
161
  awaitable = create_task(
131
- self.__set_avatar(a, avatar_unique_id),
162
+ self.__set_avatar(a, avatar_unique_id, delete),
132
163
  name=f"Set pubsub avatar of {self}",
133
164
  )
134
165
  if not self._set_avatar_task or self._set_avatar_task.done():
@@ -136,32 +167,70 @@ class AvatarMixin:
136
167
  if blocking:
137
168
  await awaitable
138
169
 
170
+ def get_cached_avatar(self) -> Optional["CachedAvatar"]:
171
+ if self._avatar_pk is None:
172
+ return None
173
+ return avatar_cache.get_by_pk(self._avatar_pk)
174
+
139
175
  def get_avatar(self) -> Optional["PepAvatar"]:
140
- if not self.__avatar_unique_id:
176
+ cached_avatar = self.get_cached_avatar()
177
+ if cached_avatar is None:
141
178
  return None
142
- return self.session.xmpp.pubsub.get_avatar(self.__avatar_jid)
179
+ from ..pubsub import PepAvatar
180
+
181
+ item = PepAvatar()
182
+ item.set_avatar_from_cache(cached_avatar)
183
+ return item
143
184
 
144
185
  def _post_avatar_update(self) -> None:
145
186
  return
146
187
 
188
+ def __get_cached_avatar_id(self):
189
+ i = self._get_cached_avatar_id()
190
+ if i is None:
191
+ return None
192
+ return self.session.xmpp.AVATAR_ID_TYPE(i)
193
+
194
+ def _get_cached_avatar_id(self) -> Optional[str]:
195
+ raise NotImplementedError
196
+
147
197
  async def avatar_wrap_update_info(self):
148
- cached_id = avatar_cache.get_cached_id_for(self.__avatar_jid)
198
+ cached_id = self.__get_cached_avatar_id()
149
199
  self.__avatar_unique_id = cached_id
150
200
  try:
151
201
  await self.update_info() # type:ignore
152
202
  except NotImplementedError:
153
203
  return
154
204
  new_id = self.avatar
155
- if isinstance(new_id, URL) and not await avatar_cache.url_has_changed(new_id):
205
+ if isinstance(new_id, URL) and not await avatar_cache.url_modified(new_id):
156
206
  return
157
207
  elif new_id != cached_id:
158
208
  # at this point it means that update_info set the avatar, and we don't
159
209
  # need to do anything else
160
210
  return
161
211
 
162
- await self.session.xmpp.pubsub.set_avatar_from_cache(
163
- self.__avatar_jid,
164
- new_id is None and cached_id is not None,
165
- self.session.user.jid.bare,
166
- self._avatar_pubsub_broadcast,
212
+ if self.__should_pubsub_broadcast():
213
+ if new_id is None and cached_id is None:
214
+ return
215
+ if self._avatar_pk is not None:
216
+ cached_avatar = avatar_cache.get_by_pk(self._avatar_pk)
217
+ else:
218
+ cached_avatar = None
219
+ self.__broadcast_task = self.session.xmpp.loop.create_task(
220
+ self.session.xmpp.pubsub.broadcast_avatar(
221
+ self.__avatar_jid, self.session.user_jid, cached_avatar
222
+ )
223
+ )
224
+
225
+ def _set_avatar_from_store(self, stored):
226
+ if stored.avatar_id is None:
227
+ return
228
+ if stored.avatar is None:
229
+ # seems to happen after avatar cleanup for some reason?
230
+ self.__avatar_unique_id = None
231
+ return
232
+ self.__avatar_unique_id = (
233
+ stored.avatar.legacy_id
234
+ if stored.avatar.legacy_id is not None
235
+ else URL(stored.avatar.url)
167
236
  )
@@ -6,9 +6,8 @@ from slixmpp import JID
6
6
  from ...util.types import MessageOrPresenceTypeVar
7
7
 
8
8
  if TYPE_CHECKING:
9
- from slidge.core.gateway import BaseGateway
10
- from slidge.core.session import BaseSession
11
- from slidge.util.db import GatewayUser
9
+ from ..gateway import BaseGateway
10
+ from ..session import BaseSession
12
11
 
13
12
 
14
13
  class MetaBase(ABCMeta):
@@ -18,11 +17,18 @@ class MetaBase(ABCMeta):
18
17
  class Base:
19
18
  session: "BaseSession" = NotImplemented
20
19
  xmpp: "BaseGateway" = NotImplemented
21
- user: "GatewayUser" = NotImplemented
22
20
 
23
21
  jid: JID = NotImplemented
24
22
  name: str = NotImplemented
25
23
 
24
+ @property
25
+ def user_jid(self):
26
+ return self.session.user_jid
27
+
28
+ @property
29
+ def user_pk(self):
30
+ return self.session.user_pk
31
+
26
32
 
27
33
  class BaseSender(Base):
28
34
  def _send(
@@ -0,0 +1,18 @@
1
+ from contextlib import contextmanager
2
+
3
+
4
class UpdateInfoMixin:
    """
    This mixin just adds a context manager that prevents committing to the DB
    on every attribute change.

    While inside ``with self.updating_info():`` the ``_updating_info`` flag
    is True. Code elsewhere presumably checks this flag to defer commits
    (not shown here). The flag is always restored on exit, even if the body
    raises. Nested uses keep it set until the outermost context exits.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._updating_info = False

    @contextmanager
    def updating_info(self):
        # Remember the previous state so nested contexts do not clear the
        # flag while an outer context is still active.
        previous = self._updating_info
        self._updating_info = True
        try:
            yield
        finally:
            # Without ``finally`` an exception in the body would leave the
            # flag stuck at True and silently disable commits afterwards.
            self._updating_info = previous
@@ -13,10 +13,6 @@ class BaseDiscoMixin(Base):
13
13
  DISCO_NAME: str = NotImplemented
14
14
  DISCO_LANG = None
15
15
 
16
- def __init__(self):
17
- super().__init__()
18
- self.__caps_cache: Optional[str] = None
19
-
20
16
  def _get_disco_name(self):
21
17
  if self.DISCO_NAME is NotImplemented:
22
18
  return self.xmpp.COMPONENT_NAME
@@ -44,17 +40,11 @@ class BaseDiscoMixin(Base):
44
40
  return info
45
41
 
46
42
  async def get_caps_ver(self, jid: OptJid = None, node: Optional[str] = None):
47
- if self.__caps_cache:
48
- return self.__caps_cache
49
43
  info = await self.get_disco_info(jid, node)
50
44
  caps = self.xmpp.plugin["xep_0115"]
51
45
  ver = caps.generate_verstring(info, caps.hash)
52
- self.__caps_cache = ver
53
46
  return ver
54
47
 
55
- def reset_caps_cache(self):
56
- self.__caps_cache = None
57
-
58
48
 
59
49
  class ChatterDiscoMixin(BaseDiscoMixin):
60
50
  AVATAR = True
@@ -15,14 +15,16 @@ class NamedLockMixin:
15
15
  locks = self.__locks
16
16
  if not locks.get(id_):
17
17
  locks[id_] = asyncio.Lock()
18
- async with locks[id_]:
19
- log.trace("acquired %s", id_) # type:ignore
20
- yield
21
- log.trace("releasing %s", id_) # type:ignore
22
- waiters = locks[id_]._waiters # type:ignore
23
- if not waiters:
24
- del locks[id_]
25
- log.trace("erasing %s", id_) # type:ignore
18
+ try:
19
+ async with locks[id_]:
20
+ log.trace("acquired %s", id_) # type:ignore
21
+ yield
22
+ finally:
23
+ log.trace("releasing %s", id_) # type:ignore
24
+ waiters = locks[id_]._waiters # type:ignore
25
+ if not waiters:
26
+ del locks[id_]
27
+ log.trace("erasing %s", id_) # type:ignore
26
28
 
27
29
  def get_lock(self, id_: Hashable):
28
30
  return self.__locks.get(id_)