plain.oauth 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
plain/oauth/models.py ADDED
@@ -0,0 +1,191 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ from plain import models, transaction
4
+ from plain.auth import get_user_model
5
+ from plain.models.db import IntegrityError, OperationalError, ProgrammingError
6
+ from plain.preflight import Error
7
+ from plain.runtime import settings
8
+ from plain.utils import timezone
9
+
10
+ from .exceptions import OAuthUserAlreadyExistsError
11
+
12
+ if TYPE_CHECKING:
13
+ from .providers import OAuthToken, OAuthUser
14
+
15
+
16
+ # NOTE: OAuthConnection.check() verifies that all provider keys in the db are also in settings.
17
+
18
+
19
class OAuthConnection(models.Model):
    """
    A link between a local user and an account on an external OAuth provider.

    Each row stores the provider key (as used in settings), the user's ID on
    the provider's system, and the most recent token data for that account.
    """

    created_at = models.DateTimeField(auto_now_add=True)
    updated_at = models.DateTimeField(auto_now=True)

    user = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        on_delete=models.CASCADE,
        related_name="oauth_connections",
    )

    # The key used to refer to this provider type (in settings)
    provider_key = models.CharField(max_length=100, db_index=True)

    # The unique ID of the user on the provider's system
    provider_user_id = models.CharField(max_length=100, db_index=True)

    # Token data
    access_token = models.CharField(max_length=2000)
    refresh_token = models.CharField(max_length=2000, blank=True)
    access_token_expires_at = models.DateTimeField(blank=True, null=True)
    refresh_token_expires_at = models.DateTimeField(blank=True, null=True)

    class Meta:
        unique_together = ("provider_key", "provider_user_id")
        ordering = ("provider_key",)

    def __str__(self):
        return f"{self.provider_key}[{self.user}:{self.provider_user_id}]"

    def refresh_access_token(self) -> None:
        """Ask the provider for a refreshed token and save it on this connection."""
        # Imported here (not at module level) because providers imports this module
        from .providers import OAuthToken, get_oauth_provider_instance

        provider_instance = get_oauth_provider_instance(provider_key=self.provider_key)
        oauth_token = OAuthToken(
            access_token=self.access_token,
            refresh_token=self.refresh_token,
            access_token_expires_at=self.access_token_expires_at,
            refresh_token_expires_at=self.refresh_token_expires_at,
        )
        refreshed_oauth_token = provider_instance.refresh_oauth_token(
            oauth_token=oauth_token
        )
        self.set_token_fields(refreshed_oauth_token)
        self.save()

    def set_token_fields(self, oauth_token: "OAuthToken"):
        """Copy the token values from `oauth_token` onto this instance (does not save)."""
        self.access_token = oauth_token.access_token
        self.refresh_token = oauth_token.refresh_token
        self.access_token_expires_at = oauth_token.access_token_expires_at
        self.refresh_token_expires_at = oauth_token.refresh_token_expires_at

    def set_user_fields(self, oauth_user: "OAuthUser"):
        """Copy the provider-side user ID from `oauth_user` onto this instance (does not save)."""
        self.provider_user_id = oauth_user.id

    def access_token_expired(self) -> bool:
        """Return True if an access token expiry is set and it is in the past."""
        return (
            self.access_token_expires_at is not None
            and self.access_token_expires_at < timezone.now()
        )

    def refresh_token_expired(self) -> bool:
        """Return True if a refresh token expiry is set and it is in the past."""
        return (
            self.refresh_token_expires_at is not None
            and self.refresh_token_expires_at < timezone.now()
        )

    @classmethod
    def get_or_createuser(
        cls, *, provider_key: str, oauth_token: "OAuthToken", oauth_user: "OAuthUser"
    ) -> "OAuthConnection":
        """
        Return the connection for this provider account, creating a user if needed.

        If a connection already exists for (provider_key, oauth_user.id), its
        token fields are updated and it is returned. Otherwise a new user is
        created from the OAuth user's username and email, and connected.

        Raises OAuthUserAlreadyExistsError if saving the new user fails with
        an IntegrityError.
        """
        try:
            connection = cls.objects.get(
                provider_key=provider_key,
                provider_user_id=oauth_user.id,
            )
        except cls.DoesNotExist:
            with transaction.atomic():
                # If email needs to be unique, then we expect
                # that to be taken care of on the user model itself
                try:
                    user = get_user_model()(
                        username=oauth_user.username,
                        email=oauth_user.email,
                    )
                    user.full_clean()
                    user.save()
                except IntegrityError as e:
                    raise OAuthUserAlreadyExistsError() from e

                return cls.connect(
                    user=user,
                    provider_key=provider_key,
                    oauth_token=oauth_token,
                    oauth_user=oauth_user,
                )
        else:
            # Existing connection: just store the latest token data
            connection.set_token_fields(oauth_token)
            connection.save()
            return connection

    @classmethod
    def connect(
        cls,
        *,
        user: settings.AUTH_USER_MODEL,
        provider_key: str,
        oauth_token: "OAuthToken",
        oauth_user: "OAuthUser",
    ) -> "OAuthConnection":
        """
        Connect will either create a new connection or update an existing connection
        """
        try:
            connection = cls.objects.get(
                user=user,
                provider_key=provider_key,
                provider_user_id=oauth_user.id,
            )
        except cls.DoesNotExist:
            # Create our own instance (not using get_or_create)
            # so that any created signals contain the token fields too
            connection = cls(
                user=user,
                provider_key=provider_key,
                provider_user_id=oauth_user.id,
            )

        connection.set_user_fields(oauth_user)
        connection.set_token_fields(oauth_token)
        connection.save()

        return connection

    @classmethod
    def check(cls, **kwargs):
        """
        A system check for ensuring that provider_keys in the database are also present in settings.

        The check only queries the database when a non-empty `databases`
        keyword argument is passed; otherwise only the inherited checks run.
        (The original note: the --database flag is required for this to work.)
        """
        errors = super().check(**kwargs)

        databases = kwargs.get("databases", None)
        if not databases:
            return errors

        from .providers import get_provider_keys

        # Settings do not vary per database, so read them once
        keys_in_settings = set(get_provider_keys())

        for database in databases:
            try:
                keys_in_db = set(
                    cls.objects.using(database)
                    .values_list("provider_key", flat=True)
                    .distinct()
                )
            except (OperationalError, ProgrammingError):
                # Check runs on manage.py migrate, and the table may not exist yet
                # or it may not be installed on the particular database intentionally
                continue

            missing_keys = keys_in_db - keys_in_settings
            if missing_keys:
                errors.append(
                    Error(
                        "The following OAuth providers are in the database but not in the settings: {}".format(
                            # Sorted so the message is stable between runs
                            ", ".join(sorted(missing_keys))
                        ),
                        id="plain.oauth.E001",
                    )
                )

        return errors
@@ -0,0 +1,192 @@
1
+ import datetime
2
+ import secrets
3
+ from typing import Any
4
+ from urllib.parse import urlencode
5
+
6
+ from plain.auth import login as auth_login
7
+ from plain.http import HttpRequest, Response, ResponseRedirect
8
+ from plain.runtime import settings
9
+ from plain.urls import reverse
10
+ from plain.utils.crypto import get_random_string
11
+ from plain.utils.module_loading import import_string
12
+
13
+ from .exceptions import OAuthError, OAuthStateMismatchError
14
+ from .models import OAuthConnection
15
+
16
+ # Session key holding the `state` value generated at login, checked on the callback
+ SESSION_STATE_KEY = "plainoauth_state"
17
+ # Session key holding the URL to redirect to after a successful callback
+ SESSION_NEXT_KEY = "plainoauth_next"
18
+
19
+
20
class OAuthToken:
    """
    Simple value holder for the token data returned by an OAuth provider.

    Only `access_token` is required; the refresh token defaults to an empty
    string and both expiry timestamps default to None.
    """

    def __init__(
        self,
        *,
        access_token: str,
        refresh_token: str = "",
        access_token_expires_at: datetime.datetime = None,
        refresh_token_expires_at: datetime.datetime = None,
    ):
        # Token strings
        self.access_token, self.refresh_token = access_token, refresh_token
        # Expiry timestamps (None when no expiry was supplied)
        self.access_token_expires_at, self.refresh_token_expires_at = (
            access_token_expires_at,
            refresh_token_expires_at,
        )
33
+
34
+
35
class OAuthUser:
    """
    Simple value holder for the user details returned by an OAuth provider.

    `id` is the user's identifier on the provider's system. The string form
    of an OAuthUser is its email.
    """

    def __init__(self, *, id: str, email: str, username: str = ""):
        self.id, self.username, self.email = id, username, email

    def __str__(self):
        email = self.email
        return email
43
+
44
+
45
class OAuthProvider:
    """
    Base class for an OAuth provider integration.

    Subclasses set `authorization_url` and implement `refresh_oauth_token`,
    `get_oauth_token`, and `get_oauth_user`. The `handle_*_request` methods
    drive the login, connect, disconnect, and callback flows.
    """

    authorization_url = ""

    def __init__(
        self,
        *,
        # Provided automatically
        provider_key: str,
        # Required as kwargs in OAUTH_LOGIN_PROVIDERS setting
        client_id: str,
        client_secret: str,
        # Not necessarily required, but commonly used
        scope: str = "",
    ):
        self.provider_key = provider_key
        self.client_id = client_id
        self.client_secret = client_secret
        self.scope = scope

    def get_authorization_url_params(self, *, request: HttpRequest) -> dict:
        """Return the query parameters appended to the authorization URL."""
        return {
            "redirect_uri": self.get_callback_url(request=request),
            "client_id": self.get_client_id(),
            "scope": self.get_scope(),
            "state": self.generate_state(),
            "response_type": "code",
        }

    def refresh_oauth_token(self, *, oauth_token: OAuthToken) -> OAuthToken:
        """Exchange an existing token for a refreshed one. Must be implemented by subclasses."""
        raise NotImplementedError()

    def get_oauth_token(self, *, code: str, request: HttpRequest) -> OAuthToken:
        """Exchange the callback `code` for a token. Must be implemented by subclasses."""
        raise NotImplementedError()

    def get_oauth_user(self, *, oauth_token: OAuthToken) -> OAuthUser:
        """Fetch the provider-side user for a token. Must be implemented by subclasses."""
        raise NotImplementedError()

    def get_authorization_url(self, *, request: HttpRequest) -> str:
        return self.authorization_url

    def get_client_id(self) -> str:
        return self.client_id

    def get_client_secret(self) -> str:
        return self.client_secret

    def get_scope(self) -> str:
        return self.scope

    def get_callback_url(self, *, request: HttpRequest) -> str:
        """Return the absolute URL of this provider's callback view."""
        url = reverse("oauth:callback", kwargs={"provider": self.provider_key})
        return request.build_absolute_uri(url)

    def generate_state(self) -> str:
        """Return a new random 32-character `state` value."""
        return get_random_string(length=32)

    def check_request_state(self, *, request: HttpRequest) -> None:
        """
        Validate the callback request against the state stored in the session.

        Raises OAuthError if the request carries an `error` parameter, and
        OAuthStateMismatchError if the `state` parameter or the stored state
        is missing, or if the two do not match.
        """
        if error := request.GET.get("error"):
            raise OAuthError(error)

        # A missing value on either side is treated as a mismatch rather than
        # being allowed to surface as a KeyError
        state = request.GET.get("state", "")
        expected_state = request.session.pop(SESSION_STATE_KEY, "")
        if not state or not expected_state:
            raise OAuthStateMismatchError()

        # Compare as bytes: compare_digest only accepts ASCII-only str, and
        # `state` comes straight from the query string
        if not secrets.compare_digest(state.encode(), expected_state.encode()):
            raise OAuthStateMismatchError()

    def handle_login_request(
        self, *, request: HttpRequest, redirect_to: str = ""
    ) -> Response:
        """
        Start the OAuth flow by redirecting to the provider's authorization URL.

        The generated state (if any) and the post-login redirect target are
        stored in the session for use by the callback request.
        """
        authorization_url = self.get_authorization_url(request=request)
        authorization_params = self.get_authorization_url_params(request=request)

        if "state" in authorization_params:
            # Store the state in the session so we can check on callback
            request.session[SESSION_STATE_KEY] = authorization_params["state"]

        # Store next url in session so we can get it on the callback request
        # NOTE(review): the POSTed "next" value is user-supplied and is not
        # validated here before later being used as a redirect target
        if redirect_to:
            request.session[SESSION_NEXT_KEY] = redirect_to
        elif "next" in request.POST:
            request.session[SESSION_NEXT_KEY] = request.POST["next"]

        # Sort authorization params for consistency
        sorted_authorization_params = sorted(authorization_params.items())
        redirect_url = authorization_url + "?" + urlencode(sorted_authorization_params)
        return ResponseRedirect(redirect_url)

    def handle_connect_request(
        self, *, request: HttpRequest, redirect_to: str = ""
    ) -> Response:
        """Start the flow for connecting a provider; same as a login request."""
        return self.handle_login_request(request=request, redirect_to=redirect_to)

    def handle_disconnect_request(self, *, request: HttpRequest) -> Response:
        """
        Delete the requesting user's connection identified by the POSTed
        `provider_user_id`, then redirect.

        Raises OAuthConnection.DoesNotExist if the requesting user has no such
        connection for this provider.
        """
        provider_user_id = request.POST["provider_user_id"]
        # Scope the lookup to the requesting user so that one user cannot
        # remove a connection belonging to another
        connection = OAuthConnection.objects.get(
            user=request.user,
            provider_key=self.provider_key,
            provider_user_id=provider_user_id,
        )
        connection.delete()
        redirect_url = self.get_disconnect_redirect_url(request=request)
        return ResponseRedirect(redirect_url)

    def handle_callback_request(self, *, request: HttpRequest) -> Response:
        """
        Complete the OAuth flow: verify state, fetch the token and user,
        create or update the connection, log the user in, and redirect.
        """
        self.check_request_state(request=request)

        oauth_token = self.get_oauth_token(code=request.GET["code"], request=request)
        oauth_user = self.get_oauth_user(oauth_token=oauth_token)

        if request.user:
            # Already logged in: attach the provider account to the current user
            connection = OAuthConnection.connect(
                user=request.user,
                provider_key=self.provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )
        else:
            # Not logged in: find the existing connection or create a new user
            connection = OAuthConnection.get_or_createuser(
                provider_key=self.provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )

        self.login(request=request, user=connection.user)

        redirect_url = self.get_login_redirect_url(request=request)
        return ResponseRedirect(redirect_url)

    def login(self, *, request: HttpRequest, user: Any) -> None:
        """Log `user` in on `request`."""
        auth_login(request=request, user=user)

    def get_login_redirect_url(self, *, request: HttpRequest) -> str:
        """Return (and clear) the stored post-login URL, defaulting to "/"."""
        return request.session.pop(SESSION_NEXT_KEY, "/")

    def get_disconnect_redirect_url(self, *, request: HttpRequest) -> str:
        """Return the POSTed "next" URL, defaulting to "/"."""
        return request.POST.get("next", "/")
181
+
182
+
183
def get_oauth_provider_instance(*, provider_key: str) -> OAuthProvider:
    """
    Build the provider instance configured under `provider_key`.

    Reads the OAUTH_LOGIN_PROVIDERS setting, imports the dotted path stored
    under "class", and instantiates it with the provider key plus any
    configured "kwargs".
    """
    providers = getattr(settings, "OAUTH_LOGIN_PROVIDERS", {})
    class_path = providers[provider_key]["class"]
    provider_cls = import_string(class_path)
    extra_kwargs = providers[provider_key].get("kwargs", {})
    return provider_cls(provider_key=provider_key, **extra_kwargs)
189
+
190
+
191
def get_provider_keys() -> list[str]:
    """Return the keys of the OAUTH_LOGIN_PROVIDERS setting (empty if unset)."""
    providers = getattr(settings, "OAUTH_LOGIN_PROVIDERS", {})
    return list(providers.keys())
@@ -0,0 +1,6 @@
1
+ {% extends "base.html" %}
2
+
3
+ {% block content %}
4
+ <h1>OAuth Error</h1>
5
+ <p>{{ oauth_error }}</p>
6
+ {% endblock %}
plain/oauth/urls.py ADDED
@@ -0,0 +1,24 @@
1
+ from plain.urls import include, path
2
+
3
+ from . import views
4
+
5
default_namespace = "oauth"

# Routes nested under each "<provider>/" prefix
_provider_urlpatterns = [
    # Login and Signup are both handled here, because the intent is the same
    path("login/", views.OAuthLoginView, name="login"),
    path("connect/", views.OAuthConnectView, name="connect"),
    path("disconnect/", views.OAuthDisconnectView, name="disconnect"),
    path("callback/", views.OAuthCallbackView, name="callback"),
]

urlpatterns = [
    path("<str:provider>/", include(_provider_urlpatterns)),
]
plain/oauth/views.py ADDED
@@ -0,0 +1,87 @@
1
+ import logging
2
+
3
+ from plain.auth.views import AuthViewMixin
4
+ from plain.http import ResponseBadRequest, ResponseRedirect
5
+ from plain.templates import jinja
6
+ from plain.views import View
7
+
8
+ from .exceptions import (
9
+ OAuthError,
10
+ OAuthStateMismatchError,
11
+ OAuthUserAlreadyExistsError,
12
+ )
13
+ from .providers import get_oauth_provider_instance
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
class OAuthLoginView(View):
    """Start the OAuth login flow for the provider named in the URL."""

    def post(self):
        provider_key = self.url_kwargs["provider"]

        # Someone who is already logged in has nothing to do here
        if self.request.user:
            return ResponseRedirect("/")

        oauth_provider = get_oauth_provider_instance(provider_key=provider_key)
        return oauth_provider.handle_login_request(request=self.request)
27
+
28
+
29
class OAuthCallbackView(View):
    """
    The callback view is used for signup, login, and connect.
    """

    def _error_response(self, message):
        """Render oauth/error.html with `message` as a 400 response."""
        error_template = jinja.get_template("oauth/error.html")
        return ResponseBadRequest(error_template.render({"oauth_error": message}))

    def get(self):
        provider_key = self.url_kwargs["provider"]
        oauth_provider = get_oauth_provider_instance(provider_key=provider_key)
        try:
            return oauth_provider.handle_callback_request(request=self.request)
        except OAuthUserAlreadyExistsError:
            return self._error_response(
                "A user already exists with this email address. Please log in first and then connect this OAuth provider to the existing account."
            )
        except OAuthStateMismatchError:
            return self._error_response(
                "The state parameter did not match. Please try again."
            )
        except OAuthError as e:
            # Any other OAuth failure is logged with its traceback
            logger.exception("OAuth error")
            return self._error_response(str(e))
62
+
63
+
64
class OAuthConnectView(AuthViewMixin, View):
    """Start the flow that connects the URL's provider to an account."""

    def post(self):
        provider_key = self.url_kwargs["provider"]
        oauth_provider = get_oauth_provider_instance(provider_key=provider_key)
        return oauth_provider.handle_connect_request(request=self.request)
70
+
71
+
72
class OAuthDisconnectView(AuthViewMixin, View):
    """Remove a connection for the provider named in the URL."""

    def post(self):
        provider_key = self.url_kwargs["provider"]
        oauth_provider = get_oauth_provider_instance(provider_key=provider_key)
        # TODO: handle an OAuthCannotDisconnectError here by rendering
        # "oauth/error.html" with status 400 and the message "This connection
        # can't be removed. You must have a usable password or at least one
        # active connection."
        return oauth_provider.handle_disconnect_request(request=self.request)
@@ -0,0 +1,28 @@
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2023, Dropseed, LLC
4
+
5
+ Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation
13
+ and/or other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its
16
+ contributors may be used to endorse or promote products derived from
17
+ this software without specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.