slidge 0.2.0a6__py3-none-any.whl → 0.2.0a9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- slidge/__version__.py +1 -1
- slidge/command/chat_command.py +15 -1
- slidge/command/user.py +2 -0
- slidge/contact/contact.py +46 -13
- slidge/contact/roster.py +0 -3
- slidge/core/gateway/base.py +15 -7
- slidge/core/gateway/session_dispatcher.py +16 -7
- slidge/core/mixins/attachment.py +51 -46
- slidge/core/mixins/avatar.py +11 -7
- slidge/core/pubsub.py +55 -68
- slidge/core/session.py +4 -4
- slidge/db/alembic/versions/3071e0fa69d4_add_contact_client_type.py +52 -0
- slidge/db/alembic/versions/5bd48bfdffa2_lift_room_legacy_id_constraint.py +19 -13
- slidge/db/alembic/versions/aa9d82a7f6ef_db_creation.py +5 -1
- slidge/db/alembic/versions/abba1ae0edb3_store_avatar_legacy_id_in_the_contact_.py +78 -0
- slidge/db/avatar.py +14 -49
- slidge/db/models.py +8 -5
- slidge/db/store.py +30 -16
- slidge/group/room.py +17 -4
- slidge/main.py +8 -3
- slidge/migration.py +15 -1
- slidge/util/test.py +8 -1
- slidge/util/types.py +4 -0
- {slidge-0.2.0a6.dist-info → slidge-0.2.0a9.dist-info}/METADATA +6 -2
- {slidge-0.2.0a6.dist-info → slidge-0.2.0a9.dist-info}/RECORD +28 -26
- {slidge-0.2.0a6.dist-info → slidge-0.2.0a9.dist-info}/LICENSE +0 -0
- {slidge-0.2.0a6.dist-info → slidge-0.2.0a9.dist-info}/WHEEL +0 -0
- {slidge-0.2.0a6.dist-info → slidge-0.2.0a9.dist-info}/entry_points.txt +0 -0
    
        slidge/core/pubsub.py
    CHANGED
    
    | @@ -22,11 +22,11 @@ from slixmpp.types import JidStr, OptJidStr | |
| 22 22 |  | 
| 23 23 | 
             
            from ..db.avatar import CachedAvatar, avatar_cache
         | 
| 24 24 | 
             
            from ..db.store import ContactStore, SlidgeStore
         | 
| 25 | 
            -
            from ..util.types import URL
         | 
| 26 25 | 
             
            from .mixins.lock import NamedLockMixin
         | 
| 27 26 |  | 
| 28 27 | 
             
            if TYPE_CHECKING:
         | 
| 29 | 
            -
                from  | 
| 28 | 
            +
                from ..contact.contact import LegacyContact
         | 
| 29 | 
            +
                from ..core.gateway.base import BaseGateway
         | 
| 30 30 |  | 
| 31 31 | 
             
            VCARD4_NAMESPACE = "urn:xmpp:vcard4"
         | 
| 32 32 |  | 
| @@ -116,7 +116,6 @@ class PubSubComponent(NamedLockMixin, BasePlugin): | |
| 116 116 | 
             
                            self._get_vcard,  # type:ignore
         | 
| 117 117 | 
             
                        )
         | 
| 118 118 | 
             
                    )
         | 
| 119 | 
            -
                    self.xmpp.add_event_handler("presence_available", self._on_presence_available)
         | 
| 120 119 |  | 
| 121 120 | 
             
                    disco = self.xmpp.plugin["xep_0030"]
         | 
| 122 121 | 
             
                    disco.add_identity("pubsub", "pep", self.component_name)
         | 
| @@ -125,63 +124,51 @@ class PubSubComponent(NamedLockMixin, BasePlugin): | |
| 125 124 | 
             
                    disco.add_feature("http://jabber.org/protocol/pubsub#retrieve-items")
         | 
| 126 125 | 
             
                    disco.add_feature("http://jabber.org/protocol/pubsub#persistent-items")
         | 
| 127 126 |  | 
| 128 | 
            -
                async def  | 
| 127 | 
            +
                async def __get_features(self, presence: Presence) -> list[str]:
         | 
| 128 | 
            +
                    from_ = presence.get_from()
         | 
| 129 | 
            +
                    ver_string = presence["caps"]["ver"]
         | 
| 130 | 
            +
                    if ver_string:
         | 
| 131 | 
            +
                        info = await self.xmpp.plugin["xep_0115"].get_caps(from_)
         | 
| 132 | 
            +
                    else:
         | 
| 133 | 
            +
                        info = None
         | 
| 134 | 
            +
                    if info is None:
         | 
| 135 | 
            +
                        async with self.lock(from_):
         | 
| 136 | 
            +
                            iq = await self.xmpp.plugin["xep_0030"].get_info(from_)
         | 
| 137 | 
            +
                        info = iq["disco_info"]
         | 
| 138 | 
            +
                    return info["features"]
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                async def on_presence_available(
         | 
| 141 | 
            +
                    self, p: Presence, contact: Optional["LegacyContact"]
         | 
| 142 | 
            +
                ):
         | 
| 129 143 | 
             
                    if p.get_plugin("muc_join", check=True) is not None:
         | 
| 130 144 | 
             
                        log.debug("Ignoring MUC presence here")
         | 
| 131 145 | 
             
                        return
         | 
| 132 146 |  | 
| 133 | 
            -
                    from_ = p.get_from()
         | 
| 134 | 
            -
                    ver_string = p["caps"]["ver"]
         | 
| 135 | 
            -
                    info = None
         | 
| 136 | 
            -
             | 
| 137 147 | 
             
                    to = p.get_to()
         | 
| 138 | 
            -
             | 
| 139 | 
            -
                    contact = None
         | 
| 140 | 
            -
                    # we don't want to push anything for contacts that are not in the user's roster
         | 
| 141 148 | 
             
                    if to != self.xmpp.boundjid.bare:
         | 
| 142 | 
            -
                         | 
| 143 | 
            -
             | 
| 144 | 
            -
                        if session is None:
         | 
| 149 | 
            +
                        # we don't want to push anything for contacts that are not in the user's roster
         | 
| 150 | 
            +
                        if contact is None or not contact.is_friend:
         | 
| 145 151 | 
             
                            return
         | 
| 146 152 |  | 
| 147 | 
            -
             | 
| 148 | 
            -
             | 
| 149 | 
            -
                            contact = await session.contacts.by_jid(to)
         | 
| 150 | 
            -
                        except XMPPError as e:
         | 
| 151 | 
            -
                            log.debug(
         | 
| 152 | 
            -
                                "Could not determine if %s was added to the roster: %s", to, e
         | 
| 153 | 
            -
                            )
         | 
| 154 | 
            -
                            return
         | 
| 155 | 
            -
                        except Exception as e:
         | 
| 156 | 
            -
                            log.warning("Could not determine if %s was added to the roster.", to)
         | 
| 157 | 
            -
                            log.exception(e)
         | 
| 158 | 
            -
                            return
         | 
| 159 | 
            -
                        if not contact.is_friend:
         | 
| 160 | 
            -
                            return
         | 
| 153 | 
            +
                    from_ = p.get_from()
         | 
| 154 | 
            +
                    features = await self.__get_features(p)
         | 
| 161 155 |  | 
| 162 | 
            -
                    if ver_string:
         | 
| 163 | 
            -
                        info = await self.xmpp.plugin["xep_0115"].get_caps(from_)
         | 
| 164 | 
            -
                    if info is None:
         | 
| 165 | 
            -
                        async with self.lock(from_):
         | 
| 166 | 
            -
                            iq = await self.xmpp.plugin["xep_0030"].get_info(from_)
         | 
| 167 | 
            -
                        info = iq["disco_info"]
         | 
| 168 | 
            -
                    features = info["features"]
         | 
| 169 156 | 
             
                    if AvatarMetadata.namespace + "+notify" in features:
         | 
| 170 157 | 
             
                        try:
         | 
| 171 | 
            -
                            pep_avatar = await self._get_authorized_avatar(p)
         | 
| 158 | 
            +
                            pep_avatar = await self._get_authorized_avatar(p, contact)
         | 
| 172 159 | 
             
                        except XMPPError:
         | 
| 173 160 | 
             
                            pass
         | 
| 174 161 | 
             
                        else:
         | 
| 175 162 | 
             
                            if pep_avatar.metadata is not None:
         | 
| 176 163 | 
             
                                await self.__broadcast(
         | 
| 177 164 | 
             
                                    data=pep_avatar.metadata,
         | 
| 178 | 
            -
                                    from_=p.get_to(),
         | 
| 165 | 
            +
                                    from_=p.get_to().bare,
         | 
| 179 166 | 
             
                                    to=from_,
         | 
| 180 167 | 
             
                                    id=pep_avatar.metadata["info"]["id"],
         | 
| 181 168 | 
             
                                )
         | 
| 182 169 | 
             
                    if UserNick.namespace + "+notify" in features:
         | 
| 183 170 | 
             
                        try:
         | 
| 184 | 
            -
                            pep_nick = await self._get_authorized_nick(p)
         | 
| 171 | 
            +
                            pep_nick = await self._get_authorized_nick(p, contact)
         | 
| 185 172 | 
             
                        except XMPPError:
         | 
| 186 173 | 
             
                            pass
         | 
| 187 174 | 
             
                        else:
         | 
| @@ -210,64 +197,64 @@ class PubSubComponent(NamedLockMixin, BasePlugin): | |
| 210 197 | 
             
                        node=VCARD4_NAMESPACE,
         | 
| 211 198 | 
             
                    )
         | 
| 212 199 |  | 
| 213 | 
            -
                async def  | 
| 200 | 
            +
                async def __get_contact(self, stanza: Union[Iq, Presence]):
         | 
| 201 | 
            +
                    session = self.xmpp.get_session_from_stanza(stanza)
         | 
| 202 | 
            +
                    return await session.contacts.by_jid(stanza.get_to())
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                async def _get_authorized_avatar(
         | 
| 205 | 
            +
                    self, stanza: Union[Iq, Presence], contact: Optional["LegacyContact"] = None
         | 
| 206 | 
            +
                ) -> PepAvatar:
         | 
| 214 207 | 
             
                    if stanza.get_to() == self.xmpp.boundjid.bare:
         | 
| 215 208 | 
             
                        item = PepAvatar()
         | 
| 216 209 | 
             
                        item.set_avatar_from_cache(avatar_cache.get_by_pk(self.xmpp.avatar_pk))
         | 
| 217 210 | 
             
                        return item
         | 
| 218 211 |  | 
| 219 | 
            -
                     | 
| 220 | 
            -
             | 
| 212 | 
            +
                    if contact is None:
         | 
| 213 | 
            +
                        contact = await self.__get_contact(stanza)
         | 
| 221 214 |  | 
| 222 215 | 
             
                    item = PepAvatar()
         | 
| 223 | 
            -
                     | 
| 224 | 
            -
             | 
| 225 | 
            -
                        stored = avatar_cache.get(
         | 
| 226 | 
            -
                            avatar_id if isinstance(avatar_id, URL) else str(avatar_id)
         | 
| 227 | 
            -
                        )
         | 
| 216 | 
            +
                    if contact.avatar_pk is not None:
         | 
| 217 | 
            +
                        stored = avatar_cache.get_by_pk(contact.avatar_pk)
         | 
| 228 218 | 
             
                        assert stored is not None
         | 
| 229 219 | 
             
                        item.set_avatar_from_cache(stored)
         | 
| 230 220 | 
             
                    return item
         | 
| 231 221 |  | 
| 232 | 
            -
                async def _get_authorized_nick( | 
| 222 | 
            +
                async def _get_authorized_nick(
         | 
| 223 | 
            +
                    self, stanza: Union[Iq, Presence], contact: Optional["LegacyContact"] = None
         | 
| 224 | 
            +
                ) -> PepNick:
         | 
| 233 225 | 
             
                    if stanza.get_to() == self.xmpp.boundjid.bare:
         | 
| 234 226 | 
             
                        return PepNick(self.xmpp.COMPONENT_NAME)
         | 
| 235 227 |  | 
| 236 | 
            -
                     | 
| 237 | 
            -
             | 
| 228 | 
            +
                    if contact is None:
         | 
| 229 | 
            +
                        contact = await self.__get_contact(stanza)
         | 
| 238 230 |  | 
| 239 | 
            -
                    if  | 
| 240 | 
            -
                        return PepNick( | 
| 231 | 
            +
                    if contact.name is not None:
         | 
| 232 | 
            +
                        return PepNick(contact.name)
         | 
| 241 233 | 
             
                    else:
         | 
| 242 234 | 
             
                        return PepNick()
         | 
| 243 235 |  | 
| 244 | 
            -
                 | 
| 245 | 
            -
                     | 
| 246 | 
            -
             | 
| 236 | 
            +
                def __reply_with(
         | 
| 237 | 
            +
                    self, iq: Iq, content: AvatarData | AvatarMetadata | None, item_id: str | None
         | 
| 238 | 
            +
                ) -> None:
         | 
| 247 239 | 
             
                    requested_items = iq["pubsub"]["items"]
         | 
| 240 | 
            +
             | 
| 248 241 | 
             
                    if len(requested_items) == 0:
         | 
| 249 | 
            -
                        self._reply_with_payload(iq,  | 
| 242 | 
            +
                        self._reply_with_payload(iq, content, item_id)
         | 
| 250 243 | 
             
                    else:
         | 
| 251 244 | 
             
                        for item in requested_items:
         | 
| 252 | 
            -
                            if item["id"] ==  | 
| 253 | 
            -
                                self._reply_with_payload(iq,  | 
| 245 | 
            +
                            if item["id"] == item_id:
         | 
| 246 | 
            +
                                self._reply_with_payload(iq, content, item_id)
         | 
| 254 247 | 
             
                                return
         | 
| 255 248 | 
             
                        else:
         | 
| 256 249 | 
             
                            raise XMPPError("item-not-found")
         | 
| 257 250 |  | 
| 258 | 
            -
                async def  | 
| 251 | 
            +
                async def _get_avatar_data(self, iq: Iq):
         | 
| 259 252 | 
             
                    pep_avatar = await self._get_authorized_avatar(iq)
         | 
| 253 | 
            +
                    self.__reply_with(iq, pep_avatar.data, pep_avatar.id)
         | 
| 260 254 |  | 
| 261 | 
            -
             | 
| 262 | 
            -
                     | 
| 263 | 
            -
             | 
| 264 | 
            -
                    else:
         | 
| 265 | 
            -
                        for item in requested_items:
         | 
| 266 | 
            -
                            if item["id"] == pep_avatar.id:
         | 
| 267 | 
            -
                                self._reply_with_payload(iq, pep_avatar.metadata, pep_avatar.id)
         | 
| 268 | 
            -
                                return
         | 
| 269 | 
            -
                        else:
         | 
| 270 | 
            -
                            raise XMPPError("item-not-found")
         | 
| 255 | 
            +
                async def _get_avatar_metadata(self, iq: Iq):
         | 
| 256 | 
            +
                    pep_avatar = await self._get_authorized_avatar(iq)
         | 
| 257 | 
            +
                    self.__reply_with(iq, pep_avatar.metadata, pep_avatar.id)
         | 
| 271 258 |  | 
| 272 259 | 
             
                async def _get_vcard(self, iq: Iq):
         | 
| 273 260 | 
             
                    # this is not the proper way that clients should retrieve VCards, but
         | 
    
        slidge/core/session.py
    CHANGED
    
    | @@ -73,8 +73,6 @@ class BaseSession( | |
| 73 73 | 
             
                session-specific.
         | 
| 74 74 | 
             
                """
         | 
| 75 75 |  | 
| 76 | 
            -
                http: aiohttp.ClientSession
         | 
| 77 | 
            -
             | 
| 78 76 | 
             
                MESSAGE_IDS_ARE_THREAD_IDS = False
         | 
| 79 77 | 
             
                """
         | 
| 80 78 | 
             
                Set this to True if the legacy service uses message IDs as thread IDs,
         | 
| @@ -106,8 +104,6 @@ class BaseSession( | |
| 106 104 | 
             
                        self
         | 
| 107 105 | 
             
                    )
         | 
| 108 106 |  | 
| 109 | 
            -
                    self.http = self.xmpp.http
         | 
| 110 | 
            -
             | 
| 111 107 | 
             
                    self.thread_creation_lock = asyncio.Lock()
         | 
| 112 108 |  | 
| 113 109 | 
             
                    self.__cached_presence: Optional[CachedPresence] = None
         | 
| @@ -118,6 +114,10 @@ class BaseSession( | |
| 118 114 | 
             
                def user(self) -> GatewayUser:
         | 
| 119 115 | 
             
                    return self.xmpp.store.users.get(self.user_jid)  # type:ignore
         | 
| 120 116 |  | 
| 117 | 
            +
                @property
         | 
| 118 | 
            +
                def http(self) -> aiohttp.ClientSession:
         | 
| 119 | 
            +
                    return self.xmpp.http
         | 
| 120 | 
            +
             | 
| 121 121 | 
             
                def __remove_task(self, fut):
         | 
| 122 122 | 
             
                    self.log.debug("Removing fut %s", fut)
         | 
| 123 123 | 
             
                    self.__tasks.remove(fut)
         | 
| @@ -0,0 +1,52 @@ | |
| 1 | 
            +
            """Add Contact.client_type
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            Revision ID: 3071e0fa69d4
         | 
| 4 | 
            +
            Revises: abba1ae0edb3
         | 
| 5 | 
            +
            Create Date: 2024-07-30 23:12:49.345593
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            """
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from typing import Sequence, Union
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import sqlalchemy as sa
         | 
| 12 | 
            +
            from alembic import op
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            # revision identifiers, used by Alembic.
         | 
| 15 | 
            +
            revision: str = "3071e0fa69d4"
         | 
| 16 | 
            +
            down_revision: Union[str, None] = "abba1ae0edb3"
         | 
| 17 | 
            +
            branch_labels: Union[str, Sequence[str], None] = None
         | 
| 18 | 
            +
            depends_on: Union[str, Sequence[str], None] = None
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            def upgrade() -> None:
         | 
| 22 | 
            +
                # ### commands auto generated by Alembic - please adjust! ###
         | 
| 23 | 
            +
                with op.batch_alter_table("contact", schema=None) as batch_op:
         | 
| 24 | 
            +
                    batch_op.add_column(
         | 
| 25 | 
            +
                        sa.Column(
         | 
| 26 | 
            +
                            "client_type",
         | 
| 27 | 
            +
                            sa.Enum(
         | 
| 28 | 
            +
                                "bot",
         | 
| 29 | 
            +
                                "console",
         | 
| 30 | 
            +
                                "game",
         | 
| 31 | 
            +
                                "handheld",
         | 
| 32 | 
            +
                                "pc",
         | 
| 33 | 
            +
                                "phone",
         | 
| 34 | 
            +
                                "sms",
         | 
| 35 | 
            +
                                "tablet",
         | 
| 36 | 
            +
                                "web",
         | 
| 37 | 
            +
                                native_enum=False,
         | 
| 38 | 
            +
                            ),
         | 
| 39 | 
            +
                            nullable=False,
         | 
| 40 | 
            +
                            server_default=sa.text("pc"),
         | 
| 41 | 
            +
                        )
         | 
| 42 | 
            +
                    )
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                # ### end Alembic commands ###
         | 
| 45 | 
            +
             | 
| 46 | 
            +
             | 
| 47 | 
            +
            def downgrade() -> None:
         | 
| 48 | 
            +
                # ### commands auto generated by Alembic - please adjust! ###
         | 
| 49 | 
            +
                with op.batch_alter_table("contact", schema=None) as batch_op:
         | 
| 50 | 
            +
                    batch_op.drop_column("client_type")
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                # ### end Alembic commands ###
         | 
| @@ -20,19 +20,25 @@ depends_on: Union[str, Sequence[str], None] = None | |
| 20 20 |  | 
| 21 21 |  | 
| 22 22 | 
             
            def upgrade() -> None:
         | 
| 23 | 
            -
                 | 
| 24 | 
            -
                     | 
| 25 | 
            -
             | 
| 26 | 
            -
             | 
| 27 | 
            -
             | 
| 28 | 
            -
             | 
| 29 | 
            -
             | 
| 30 | 
            -
                    batch_op | 
| 31 | 
            -
                         | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
| 34 | 
            -
                         | 
| 35 | 
            -
             | 
| 23 | 
            +
                try:
         | 
| 24 | 
            +
                    with op.batch_alter_table(
         | 
| 25 | 
            +
                        "room",
         | 
| 26 | 
            +
                        schema=None,
         | 
| 27 | 
            +
                        # without copy_from, the newly created table keeps the constraints
         | 
| 28 | 
            +
                        # we actually want to ditch.
         | 
| 29 | 
            +
                        copy_from=Room.__table__,  # type:ignore
         | 
| 30 | 
            +
                    ) as batch_op:
         | 
| 31 | 
            +
                        batch_op.create_unique_constraint(
         | 
| 32 | 
            +
                            "uq_room_user_account_id_jid", ["user_account_id", "jid"]
         | 
| 33 | 
            +
                        )
         | 
| 34 | 
            +
                        batch_op.create_unique_constraint(
         | 
| 35 | 
            +
                            "uq_room_user_account_id_legacy_id", ["user_account_id", "legacy_id"]
         | 
| 36 | 
            +
                        )
         | 
| 37 | 
            +
                except Exception:
         | 
| 38 | 
            +
                    # happens when migration is not needed
         | 
| 39 | 
            +
                    # wouldn't be necessary if the constraint was named in the first place,
         | 
| 40 | 
            +
                    # cf https://alembic.sqlalchemy.org/en/latest/naming.html
         | 
| 41 | 
            +
                    pass
         | 
| 36 42 |  | 
| 37 43 |  | 
| 38 44 | 
             
            def downgrade() -> None:
         | 
| @@ -60,7 +60,11 @@ def downgrade() -> None: | |
| 60 60 | 
             
            def migrate_from_shelf(accounts: sa.Table) -> None:
         | 
| 61 61 | 
             
                from slidge import global_config
         | 
| 62 62 |  | 
| 63 | 
            -
                 | 
| 63 | 
            +
                home = getattr(global_config, "HOME_DIR", None)
         | 
| 64 | 
            +
                if home is None:
         | 
| 65 | 
            +
                    return
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                db_file = home / "slidge.db"
         | 
| 64 68 | 
             
                if not db_file.exists():
         | 
| 65 69 | 
             
                    return
         | 
| 66 70 |  | 
| @@ -0,0 +1,78 @@ | |
| 1 | 
            +
            """Store avatar legacy ID in the Contact and Room table
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            Revision ID: abba1ae0edb3
         | 
| 4 | 
            +
            Revises: 8b993243a536
         | 
| 5 | 
            +
            Create Date: 2024-07-29 15:44:41.557388
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            """
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from typing import Sequence, Union
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import sqlalchemy as sa
         | 
| 12 | 
            +
            from alembic import op
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from slidge.db.models import Contact, Room
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            # revision identifiers, used by Alembic.
         | 
| 17 | 
            +
            revision: str = "abba1ae0edb3"
         | 
| 18 | 
            +
            down_revision: Union[str, None] = "8b993243a536"
         | 
| 19 | 
            +
            branch_labels: Union[str, Sequence[str], None] = None
         | 
| 20 | 
            +
            depends_on: Union[str, Sequence[str], None] = None
         | 
| 21 | 
            +
             | 
| 22 | 
            +
             | 
| 23 | 
            +
            def upgrade() -> None:
         | 
| 24 | 
            +
                conn = op.get_bind()
         | 
| 25 | 
            +
                room_avatars = conn.execute(
         | 
| 26 | 
            +
                    sa.text(
         | 
| 27 | 
            +
                        "select room.id, avatar.legacy_id from room join avatar on room.avatar_id = avatar.id"
         | 
| 28 | 
            +
                    )
         | 
| 29 | 
            +
                ).all()
         | 
| 30 | 
            +
                contact_avatars = conn.execute(
         | 
| 31 | 
            +
                    sa.text(
         | 
| 32 | 
            +
                        "select contact.id, avatar.legacy_id from contact join avatar on contact.avatar_id = avatar.id"
         | 
| 33 | 
            +
                    )
         | 
| 34 | 
            +
                ).all()
         | 
| 35 | 
            +
                with op.batch_alter_table("contact", schema=None) as batch_op:
         | 
| 36 | 
            +
                    batch_op.add_column(sa.Column("avatar_legacy_id", sa.String(), nullable=True))
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                with op.batch_alter_table("room", schema=None) as batch_op:
         | 
| 39 | 
            +
                    batch_op.add_column(sa.Column("avatar_legacy_id", sa.String(), nullable=True))
         | 
| 40 | 
            +
                    batch_op.create_unique_constraint(
         | 
| 41 | 
            +
                        "uq_room_user_account_id_jid", ["user_account_id", "jid"]
         | 
| 42 | 
            +
                    )
         | 
| 43 | 
            +
                    batch_op.create_unique_constraint(
         | 
| 44 | 
            +
                        "uq_room_user_account_id_legacy_id", ["user_account_id", "legacy_id"]
         | 
| 45 | 
            +
                    )
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                for room_pk, avatar_legacy_id in room_avatars:
         | 
| 48 | 
            +
                    conn.execute(
         | 
| 49 | 
            +
                        sa.update(Room)
         | 
| 50 | 
            +
                        .where(Room.id == room_pk)
         | 
| 51 | 
            +
                        .values(avatar_legacy_id=avatar_legacy_id)
         | 
| 52 | 
            +
                    )
         | 
| 53 | 
            +
                for contact_pk, avatar_legacy_id in contact_avatars:
         | 
| 54 | 
            +
                    conn.execute(
         | 
| 55 | 
            +
                        sa.update(Contact)
         | 
| 56 | 
            +
                        .where(Contact.id == contact_pk)
         | 
| 57 | 
            +
                        .values(avatar_legacy_id=avatar_legacy_id)
         | 
| 58 | 
            +
                    )
         | 
| 59 | 
            +
                # conn.commit()
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                with op.batch_alter_table("avatar", schema=None) as batch_op:
         | 
| 62 | 
            +
                    batch_op.drop_column("legacy_id")
         | 
| 63 | 
            +
             | 
| 64 | 
            +
             | 
| 65 | 
            +
            def downgrade() -> None:
         | 
| 66 | 
            +
                # ### commands auto generated by Alembic - please adjust! ###
         | 
| 67 | 
            +
                with op.batch_alter_table("room", schema=None) as batch_op:
         | 
| 68 | 
            +
                    batch_op.drop_constraint("uq_room_user_account_id_legacy_id", type_="unique")
         | 
| 69 | 
            +
                    batch_op.drop_constraint("uq_room_user_account_id_jid", type_="unique")
         | 
| 70 | 
            +
                    batch_op.drop_column("avatar_legacy_id")
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                with op.batch_alter_table("contact", schema=None) as batch_op:
         | 
| 73 | 
            +
                    batch_op.drop_column("avatar_legacy_id")
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                with op.batch_alter_table("avatar", schema=None) as batch_op:
         | 
| 76 | 
            +
                    batch_op.add_column(sa.Column("legacy_id", sa.VARCHAR(), nullable=True))
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                # ### end Alembic commands ###
         | 
    
        slidge/db/avatar.py
    CHANGED
    
    | @@ -7,7 +7,7 @@ from concurrent.futures import ThreadPoolExecutor | |
| 7 7 | 
             
            from dataclasses import dataclass
         | 
| 8 8 | 
             
            from http import HTTPStatus
         | 
| 9 9 | 
             
            from pathlib import Path
         | 
| 10 | 
            -
            from typing import  | 
| 10 | 
            +
            from typing import Optional
         | 
| 11 11 |  | 
| 12 12 | 
             
            import aiohttp
         | 
| 13 13 | 
             
            from multidict import CIMultiDictProxy
         | 
| @@ -18,7 +18,7 @@ from sqlalchemy import select | |
| 18 18 | 
             
            from slidge.core import config
         | 
| 19 19 | 
             
            from slidge.db.models import Avatar
         | 
| 20 20 | 
             
            from slidge.db.store import AvatarStore
         | 
| 21 | 
            -
            from slidge.util.types import URL, AvatarType | 
| 21 | 
            +
            from slidge.util.types import URL, AvatarType
         | 
| 22 22 |  | 
| 23 23 |  | 
| 24 24 | 
             
            @dataclass
         | 
| @@ -62,7 +62,6 @@ class AvatarCache: | |
| 62 62 | 
             
                dir: Path
         | 
| 63 63 | 
             
                http: aiohttp.ClientSession
         | 
| 64 64 | 
             
                store: AvatarStore
         | 
| 65 | 
            -
                legacy_avatar_type: Callable[[str], Any] = str
         | 
| 66 65 |  | 
| 67 66 | 
             
                def __init__(self):
         | 
| 68 67 | 
             
                    self._thread_pool = ThreadPoolExecutor(config.AVATAR_RESAMPLING_THREADS)
         | 
| @@ -119,15 +118,6 @@ class AvatarCache: | |
| 119 118 | 
             
                    headers = self.__get_http_headers(cached)
         | 
| 120 119 | 
             
                    return await self.__is_modified(url, headers)
         | 
| 121 120 |  | 
| 122 | 
            -
                def get(self, unique_id: LegacyFileIdType | URL) -> Optional[CachedAvatar]:
         | 
| 123 | 
            -
                    if isinstance(unique_id, URL):
         | 
| 124 | 
            -
                        stored = self.store.get_by_url(unique_id)
         | 
| 125 | 
            -
                    else:
         | 
| 126 | 
            -
                        stored = self.store.get_by_legacy_id(str(unique_id))
         | 
| 127 | 
            -
                    if stored is None:
         | 
| 128 | 
            -
                        return None
         | 
| 129 | 
            -
                    return CachedAvatar.from_store(stored, self.dir)
         | 
| 130 | 
            -
             | 
| 131 121 | 
             
                def get_by_pk(self, pk: int) -> CachedAvatar:
         | 
| 132 122 | 
             
                    stored = self.store.get_by_pk(pk)
         | 
| 133 123 | 
             
                    assert stored is not None
         | 
| @@ -141,40 +131,21 @@ class AvatarCache: | |
| 141 131 | 
             
                        return open_image(avatar)
         | 
| 142 132 | 
             
                    raise TypeError("Avatar must be bytes or a Path", avatar)
         | 
| 143 133 |  | 
| 144 | 
            -
                async def convert_or_get(
         | 
| 145 | 
            -
                    self,
         | 
| 146 | 
            -
                    avatar: AvatarType,
         | 
| 147 | 
            -
                    unique_id: Optional[LegacyFileIdType],
         | 
| 148 | 
            -
                ) -> CachedAvatar:
         | 
| 149 | 
            -
                    if unique_id is not None:
         | 
| 150 | 
            -
                        cached = self.get(str(unique_id))
         | 
| 151 | 
            -
                        if cached is not None:
         | 
| 152 | 
            -
                            return cached
         | 
| 153 | 
            -
             | 
| 134 | 
            +
                async def convert_or_get(self, avatar: AvatarType) -> CachedAvatar:
         | 
| 154 135 | 
             
                    if isinstance(avatar, (URL, str)):
         | 
| 155 | 
            -
                         | 
| 156 | 
            -
                             | 
| 157 | 
            -
             | 
| 158 | 
            -
                                 | 
| 159 | 
            -
                                     | 
| 160 | 
            -
             | 
| 161 | 
            -
             | 
| 162 | 
            -
                                 | 
| 163 | 
            -
             | 
| 164 | 
            -
                                    return CachedAvatar.from_store(stored, self.dir)
         | 
| 165 | 
            -
                        else:
         | 
| 166 | 
            -
                            img, _ = await self.__download(avatar, {})
         | 
| 167 | 
            -
                            response_headers = None
         | 
| 136 | 
            +
                        with self.store.session():
         | 
| 137 | 
            +
                            stored = self.store.get_by_url(avatar)
         | 
| 138 | 
            +
                            try:
         | 
| 139 | 
            +
                                img, response_headers = await self.__download(
         | 
| 140 | 
            +
                                    avatar, self.__get_http_headers(stored)
         | 
| 141 | 
            +
                                )
         | 
| 142 | 
            +
                            except NotModified:
         | 
| 143 | 
            +
                                assert stored is not None
         | 
| 144 | 
            +
                                return CachedAvatar.from_store(stored, self.dir)
         | 
| 168 145 | 
             
                    else:
         | 
| 169 146 | 
             
                        img = await self._get_image(avatar)
         | 
| 170 147 | 
             
                        response_headers = None
         | 
| 171 148 | 
             
                    with self.store.session() as orm:
         | 
| 172 | 
            -
                        stored = orm.execute(
         | 
| 173 | 
            -
                            select(Avatar).where(Avatar.legacy_id == str(unique_id))
         | 
| 174 | 
            -
                        ).scalar()
         | 
| 175 | 
            -
                        if stored is not None and stored.url is None:
         | 
| 176 | 
            -
                            return CachedAvatar.from_store(stored, self.dir)
         | 
| 177 | 
            -
             | 
| 178 149 | 
             
                        resize = (size := config.AVATAR_SIZE) and any(x > size for x in img.size)
         | 
| 179 150 | 
             
                        if resize:
         | 
| 180 151 | 
             
                            await asyncio.get_event_loop().run_in_executor(
         | 
| @@ -188,8 +159,8 @@ class AvatarCache: | |
| 188 159 | 
             
                        if (
         | 
| 189 160 | 
             
                            not resize
         | 
| 190 161 | 
             
                            and img.format == "PNG"
         | 
| 191 | 
            -
                            and isinstance( | 
| 192 | 
            -
                            and (path := Path( | 
| 162 | 
            +
                            and isinstance(avatar, (str, Path))
         | 
| 163 | 
            +
                            and (path := Path(avatar))
         | 
| 193 164 | 
             
                            and path.exists()
         | 
| 194 165 | 
             
                        ):
         | 
| 195 166 | 
             
                            img_bytes = path.read_bytes()
         | 
| @@ -206,11 +177,6 @@ class AvatarCache: | |
| 206 177 | 
             
                        stored = orm.execute(select(Avatar).where(Avatar.hash == hash_)).scalar()
         | 
| 207 178 |  | 
| 208 179 | 
             
                        if stored is not None:
         | 
| 209 | 
            -
                            if unique_id is not None:
         | 
| 210 | 
            -
                                log.warning("Updating 'unique' IDs of a known avatar.")
         | 
| 211 | 
            -
                                stored.legacy_id = str(unique_id)
         | 
| 212 | 
            -
                                orm.add(stored)
         | 
| 213 | 
            -
                                orm.commit()
         | 
| 214 180 | 
             
                            return CachedAvatar.from_store(stored, self.dir)
         | 
| 215 181 |  | 
| 216 182 | 
             
                        stored = Avatar(
         | 
| @@ -218,7 +184,6 @@ class AvatarCache: | |
| 218 184 | 
             
                            hash=hash_,
         | 
| 219 185 | 
             
                            height=img.height,
         | 
| 220 186 | 
             
                            width=img.width,
         | 
| 221 | 
            -
                            legacy_id=None if unique_id is None else str(unique_id),
         | 
| 222 187 | 
             
                            url=avatar if isinstance(avatar, (URL, str)) else None,
         | 
| 223 188 | 
             
                        )
         | 
| 224 189 | 
             
                        if response_headers:
         | 
    
        slidge/db/models.py
    CHANGED
    
    | @@ -9,7 +9,7 @@ from slixmpp.types import MucAffiliation, MucRole | |
| 9 9 | 
             
            from sqlalchemy import ForeignKey, Index, UniqueConstraint
         | 
| 10 10 | 
             
            from sqlalchemy.orm import Mapped, mapped_column, relationship
         | 
| 11 11 |  | 
| 12 | 
            -
            from ..util.types import MucType
         | 
| 12 | 
            +
            from ..util.types import ClientType, MucType
         | 
| 13 13 | 
             
            from .meta import Base, JSONSerializable, JSONSerializableTypes
         | 
| 14 14 |  | 
| 15 15 |  | 
| @@ -114,9 +114,6 @@ class Avatar(Base): | |
| 114 114 | 
             
                height: Mapped[int] = mapped_column()
         | 
| 115 115 | 
             
                width: Mapped[int] = mapped_column()
         | 
| 116 116 |  | 
| 117 | 
            -
                # legacy network-wide unique identifier for the avatar
         | 
| 118 | 
            -
                legacy_id: Mapped[Optional[str]] = mapped_column(unique=True, nullable=True)
         | 
| 119 | 
            -
             | 
| 120 117 | 
             
                # this is only used when avatars are available as HTTP URLs and do not
         | 
| 121 118 | 
             
                # have a legacy_id
         | 
| 122 119 | 
             
                url: Mapped[Optional[str]] = mapped_column(default=None)
         | 
| @@ -171,6 +168,10 @@ class Contact(Base): | |
| 171 168 |  | 
| 172 169 | 
             
                participants: Mapped[list["Participant"]] = relationship(back_populates="contact")
         | 
| 173 170 |  | 
| 171 | 
            +
                avatar_legacy_id: Mapped[Optional[str]] = mapped_column(nullable=True)
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                client_type: Mapped[ClientType] = mapped_column(nullable=False, default="pc")
         | 
| 174 | 
            +
             | 
| 174 175 |  | 
| 175 176 | 
             
            class ContactSent(Base):
         | 
| 176 177 | 
             
                """
         | 
| @@ -235,6 +236,8 @@ class Room(Base): | |
| 235 236 | 
             
                    back_populates="room", primaryjoin="Participant.room_id == Room.id"
         | 
| 236 237 | 
             
                )
         | 
| 237 238 |  | 
| 239 | 
            +
                avatar_legacy_id: Mapped[Optional[str]] = mapped_column(nullable=True)
         | 
| 240 | 
            +
             | 
| 238 241 |  | 
| 239 242 | 
             
            class ArchivedMessage(Base):
         | 
| 240 243 | 
             
                """
         | 
| @@ -366,7 +369,7 @@ class Participant(Base): | |
| 366 369 | 
             
                )
         | 
| 367 370 |  | 
| 368 371 | 
             
                contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True)
         | 
| 369 | 
            -
                contact: Mapped[Contact] = relationship(back_populates="participants")
         | 
| 372 | 
            +
                contact: Mapped[Contact] = relationship(lazy=False, back_populates="participants")
         | 
| 370 373 |  | 
| 371 374 | 
             
                is_user: Mapped[bool] = mapped_column(default=False)
         | 
| 372 375 |  |