hakai_api 1.5.2__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hakai_api/__init__.py CHANGED
@@ -1,3 +1,24 @@
1
- from hakai_api.Client import Client
1
+ """Hakai API Python Client.
2
2
 
3
- __all__ = [Client]
3
+ A Python library for making authenticated HTTP requests to the Hakai API
4
+ resource server. Extends the functionality of the Python requests library
5
+ to supply Hakai OAuth2 credentials with URL requests.
6
+
7
+ The client supports both web-based authentication (copy/paste credentials)
8
+ and desktop OAuth2 flows with PKCE for secure credential management.
9
+
10
+ Example:
11
+ Basic usage:
12
+
13
+ >>> from hakai_api import Client
14
+ >>> client = Client()
15
+ >>> response = client.get("/eims/views/output/stations")
16
+ >>> data = response.json()
17
+
18
+ Classes:
19
+ Client: Main API client class for authenticated requests.
20
+ """
21
+
22
+ from hakai_api.client import Client
23
+
24
+ __all__ = ["Client"]
hakai_api/auth/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ """Authentication strategies for the Hakai API client."""
2
+
3
+ from .base import AuthStrategy
4
+ from .desktop import DesktopAuthStrategy
5
+ from .web import WebAuthStrategy
6
+
7
+ __all__ = ["AuthStrategy", "WebAuthStrategy", "DesktopAuthStrategy"]
hakai_api/auth/base.py ADDED
@@ -0,0 +1,256 @@
1
+ """Base authentication strategy interface."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from abc import ABC, abstractmethod
7
+ from datetime import datetime
8
+ from time import mktime
9
+ from typing import TYPE_CHECKING
10
+
11
+ import requests
12
+ from loguru import logger
13
+
14
+ if TYPE_CHECKING:
15
+ from pathlib import Path
16
+
17
+
18
class AuthStrategy(ABC):
    """Abstract base class for authentication strategies.

    Subclasses implement :meth:`get_credentials` and :attr:`client_type`.
    This base class supplies the shared helpers for caching credentials in a
    JSON file, parsing ``key=value&key=value`` credential strings, checking
    expiry, and refreshing an access token.
    """

    def __init__(self, api_root: str, credentials_file: Path, **kwargs: object) -> None:
        """Initialize the authentication strategy.

        Args:
            api_root: The base url of the hakai api.
            credentials_file: The path to the credentials file.
            **kwargs: Additional strategy-specific parameters. They are
                accepted so subclasses can forward their keyword arguments
                and are not used by the base class.
        """
        self.api_root = api_root
        self.credentials_file = credentials_file

    @abstractmethod
    def get_credentials(self) -> dict:
        """Get authentication credentials using this strategy.

        Returns:
            A dictionary containing the authentication credentials.

        Raises:
            ValueError: If credentials could not be obtained.
        """

    @property
    @abstractmethod
    def client_type(self) -> str:
        """Get the client type for this authentication strategy.

        Returns:
            The client type string (e.g., 'web', 'desktop').
        """

    @staticmethod
    def _current_timestamp() -> int:
        """Return the current local time as whole seconds since the epoch."""
        return int(datetime.now().timestamp())

    def save_credentials_to_file(self, credentials: dict) -> None:
        """Save the credentials object to a file as JSON.

        Args:
            credentials: Credentials object.

        Raises:
            OSError: If file cannot be created or written to.
            TypeError: If credentials cannot be serialized to JSON.
        """
        try:
            # Ensure parent directory exists
            self.credentials_file.parent.mkdir(parents=True, exist_ok=True)
            with self.credentials_file.open("w") as outfile:
                json.dump(credentials, outfile)
            logger.trace(f"Credentials saved to {self.credentials_file}")
        except (OSError, TypeError) as e:
            logger.error(f"Failed to save credentials to file: {e}")
            raise

    def get_credentials_from_file(self) -> dict:
        """Get user credentials from a cached file.

        Loads and validates credentials from the cached credentials file.

        Returns:
            A dict containing the credentials with required keys and proper types.

        Raises:
            OSError: If the file cannot be opened or read.
            ValueError: If the file is not valid JSON or the credentials fail
                validation.
        """
        with self.credentials_file.open() as infile:
            loaded = json.load(infile)
        return self._check_keys_convert_types(loaded)

    def file_credentials_are_valid(self) -> bool:
        """Check if the cached credentials exist and are valid.

        Validates that the credentials file exists, can be parsed,
        contains required fields, and has not expired. An unreadable,
        malformed or expired file is removed.

        Returns:
            True if the credentials are valid, False otherwise.
        """
        if not self.credentials_file.is_file():
            logger.trace("No cached credentials file found")
            return False

        try:
            credentials = self.get_credentials_from_file()
            expires_at = credentials["expires_at"]
        except (KeyError, ValueError, OSError, json.JSONDecodeError) as e:
            logger.warning(f"Invalid cached credentials file, removing: {e}")
            try:
                self.credentials_file.unlink()
            except OSError:
                pass  # Best effort: file might already be gone
            return False

        if self._current_timestamp() > expires_at:
            logger.debug("Cached credentials have expired, removing")
            self.reset_credentials()
            return False

        logger.trace("Cached credentials are valid")
        return True

    def reset_credentials(self) -> None:
        """Remove the cached credentials file.

        Deletes the credentials file from the filesystem if it exists.
        """
        if self.credentials_file.is_file():
            logger.debug("Removing cached credentials file")
            self.credentials_file.unlink()
        else:
            logger.trace("No cached credentials file to remove")

    def parse_credentials_string(self, credentials: str) -> dict:
        """Parse a credentials string into a dictionary.

        The string is a ``&``-separated list of ``key=value`` pairs. Only the
        first ``=`` of each pair separates key from value, so values that
        contain ``=`` themselves are preserved.

        Args:
            credentials: The credentials string.

        Returns:
            A dictionary containing the credentials.

        Raises:
            ValueError: If the string format is invalid, a pair has no ``=``,
                or required credential keys are missing after parsing.
            AttributeError: If the string lacks expected string methods.
            KeyError: If required credential keys are missing after parsing.
        """
        logger.trace("Parsing credentials string")
        try:
            # A pair without "=" splits into a single element, which dict()
            # rejects with ValueError, so malformed input is still reported.
            parsed = dict(pair.split("=", 1) for pair in credentials.split("&"))
            parsed = self._check_keys_convert_types(parsed)
            logger.trace("Successfully parsed and validated credentials string")
            return parsed
        except (ValueError, AttributeError, KeyError) as e:
            logger.error(f"Failed to parse credentials string: {e}")
            raise

    def _check_keys_convert_types(self, credentials: dict) -> dict:
        """Check and clean the credentials.

        Validates that required keys are present and converts string values
        to appropriate types (expires_at and expires_in to integers). The
        dictionary is updated in place and also returned.

        Args:
            credentials: credentials dictionary to validate and clean.

        Returns:
            updated credentials dictionary with proper types.

        Raises:
            ValueError: if ``credentials`` is not a dictionary, if required
                keys (access_token, token_type, expires_at) are missing, or if
                expires_at / expires_in cannot be converted to an integer.
        """
        # A cached JSON file may hold any JSON value; reject non-objects here
        # so callers only ever have to handle ValueError.
        if not isinstance(credentials, dict):
            logger.error(f"Credentials have unexpected type: {type(credentials).__name__}")
            raise ValueError(f"Credentials must be a dictionary, got {type(credentials).__name__}.")

        missing_keys = [key for key in ("access_token", "token_type", "expires_at") if key not in credentials]
        if missing_keys:
            logger.error(f"Credentials missing required keys: {missing_keys}")
            raise ValueError(f"Credentials string is missing required keys: {missing_keys}.")

        # Convert expires_at to int
        try:
            credentials["expires_at"] = int(float(credentials["expires_at"]))
            logger.trace(f"Credentials expire at timestamp: {credentials['expires_at']}")
        except (ValueError, TypeError) as e:
            logger.error(f"Invalid expires_at value: {e}")
            raise ValueError(f"Invalid expires_at value in credentials: {e}") from e

        # If expires_in is present, convert to int
        if "expires_in" in credentials:
            try:
                credentials["expires_in"] = int(float(credentials["expires_in"]))
            except (ValueError, TypeError) as e:
                logger.error(f"Invalid expires_in value: {e}")
                raise ValueError(f"Invalid expires_in value in credentials: {e}") from e

        return credentials

    def _are_credentials_expired(self, credentials: dict) -> bool:
        """Check if the provided credentials are expired.

        Args:
            credentials: Credentials dictionary to check.

        Returns:
            True if credentials are expired, False otherwise. Credentials
            without an ``expires_at`` entry are treated as not expired; an
            ``expires_at`` that cannot be compared to a number is treated as
            expired.
        """
        try:
            expires_at = credentials.get("expires_at")
            if expires_at is None:
                return False  # If no expiry, assume valid
            return self._current_timestamp() > expires_at
        except (TypeError, ValueError):
            return True  # If we can't parse the expiry, consider it expired

    def refresh_token(self, credentials: dict) -> dict | None:
        """Refresh the access token using the refresh token.

        Posts the refresh token and this strategy's client type to
        ``{api_root}/auth/refresh``. The input dictionary is not modified.

        Args:
            credentials: Current credentials dictionary containing refresh_token.

        Returns:
            Updated credentials dictionary if successful, None otherwise
            (no refresh token, non-200 response, network error, or a response
            that lacks the expected fields).
        """
        if "refresh_token" not in credentials:
            logger.trace("No refresh token available, cannot refresh")
            return None

        logger.trace("Attempting to refresh access token")

        refresh_url = f"{self.api_root}/auth/refresh"
        data = {
            "refresh_token": credentials["refresh_token"],
            "client_type": self.client_type,
        }

        try:
            response = requests.post(refresh_url, json=data, timeout=10)

            if response.status_code != 200:
                logger.warning(f"Token refresh failed with status {response.status_code}")
                return None

            new_tokens = response.json()

            # Update a copy so the caller's dictionary is left untouched
            updated_credentials = credentials.copy()
            updated_credentials["access_token"] = new_tokens["access_token"]
            updated_credentials["expires_at"] = new_tokens["expires_at"]
            updated_credentials["expires_in"] = new_tokens["expires_in"]

            # If the server rotates the refresh token, the old one must be
            # replaced; otherwise the next refresh would use a stale token.
            if new_tokens.get("refresh_token"):
                updated_credentials["refresh_token"] = new_tokens["refresh_token"]

            logger.trace("Access token refreshed successfully")
            return updated_credentials

        # TypeError covers a JSON body that is not an object (e.g. a list).
        except (requests.RequestException, json.JSONDecodeError, KeyError, TypeError) as e:
            logger.error(f"Token refresh failed with exception: {e}")
            return None
hakai_api/auth/desktop.py ADDED
@@ -0,0 +1,263 @@
1
+ """Desktop OAuth authentication strategy with PKCE."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ import secrets
8
+ import webbrowser
9
+ from http.server import BaseHTTPRequestHandler, HTTPServer
10
+ from typing import Any
11
+ from urllib.parse import parse_qs, urlencode, urlparse
12
+
13
+ import pkce
14
+ import requests
15
+ from loguru import logger
16
+
17
+ from .base import AuthStrategy
18
+
19
+
20
class DesktopAuthStrategy(AuthStrategy):
    """Desktop OAuth authentication strategy using PKCE flow.

    This strategy implements the OAuth2 Authorization Code flow with PKCE
    (Proof Key for Code Exchange) for native desktop applications: it opens
    the authorization page in a browser, receives the redirect on a local
    HTTP server bound to 127.0.0.1, and exchanges the code for tokens.
    """

    def __init__(self, api_root: str, local_port: int = 65500, callback_timeout: int = 120, **kwargs: object) -> None:
        """Initialize the desktop authentication strategy.

        Args:
            api_root: The base url of the hakai api.
            local_port: Port for local callback server.
            callback_timeout: Timeout for callback server in seconds.
            **kwargs: Additional parameters, forwarded to the base class
                (which requires ``credentials_file``).
        """
        super().__init__(api_root, **kwargs)
        self.local_port = local_port

        # OAuth state variables
        self._state = None
        self._code_verifier = None
        self._authorization_code = None
        self._callback_timeout = callback_timeout

    @property
    def _redirect_uri(self) -> str:
        """Return the local callback URL; must be identical in both OAuth steps."""
        return f"http://127.0.0.1:{self.local_port}/callback"

    def get_credentials(self) -> dict:
        """Get user credentials using desktop OAuth flow with PKCE.

        Sources are tried in this order: the ``HAKAI_API_CREDENTIALS``
        environment variable, the cached credentials file, and finally the
        interactive OAuth flow.

        Returns:
            A dict containing the credentials.
        """
        # Try environment variable first
        env_credentials = os.getenv("HAKAI_API_CREDENTIALS")
        if env_credentials is not None:
            logger.trace("Loading credentials from environment variable")
            try:
                parsed_creds = self.parse_credentials_string(env_credentials)
                # Check if environment credentials are expired
                if self._are_credentials_expired(parsed_creds):
                    logger.warning("Environment variable credentials have expired")
                else:
                    return parsed_creds
            except (ValueError, KeyError) as e:
                logger.warning(f"Invalid environment variable credentials: {e}")

        # Try cached credentials
        if self.file_credentials_are_valid():
            logger.trace("Loading cached credentials from file")
            return self.get_credentials_from_file()

        # Start OAuth flow
        logger.debug("No valid cached credentials found, starting desktop authentication flow")
        return self._get_credentials_from_desktop_oauth()

    def _get_credentials_from_desktop_oauth(self) -> dict:
        """Get user credentials using desktop OAuth flow with PKCE.

        Returns:
            A dict containing the credentials.

        Raises:
            ValueError: If no authorization code was received, the callback
                reported an error, or the token exchange was rejected.
            KeyError: If the token response lacks an expected field.
        """
        # Generate PKCE parameters
        self._code_verifier, code_challenge = pkce.generate_pkce_pair()

        # Generate state for CSRF protection
        self._state = secrets.token_urlsafe(32)

        # Build authorization URL for desktop endpoint
        params = {
            "redirect_uri": self._redirect_uri,
            "code_challenge": code_challenge,
            "code_challenge_method": "S256",
            "state": self._state,
        }

        # Use the desktop auth endpoint
        auth_url = f"{self.api_root}/auth/desktop?{urlencode(params)}"

        # webbrowser.open reports False when no browser could be launched;
        # surface the URL so the user can still complete the login manually.
        if not webbrowser.open(auth_url):
            logger.warning(f"Could not open a web browser, please visit this URL to log in: {auth_url}")

        # Start local server to receive callback
        logger.trace(f"Starting local callback server on port {self.local_port}")
        self._authorization_code = self._wait_for_callback()

        if not self._authorization_code:
            logger.error("Failed to receive authorization code from OAuth callback")
            raise ValueError("Failed to receive authorization code")

        logger.trace("Successfully received authorization code, exchanging for tokens")
        # Exchange code for tokens
        tokens = self._exchange_code_for_tokens()

        # Convert desktop token response to match web format
        credentials = {
            "access_token": tokens["access_token"],
            "token_type": tokens["token_type"],
            "expires_at": tokens["expires_at"],
            "expires_in": tokens["expires_in"],
        }

        # Store refresh token if provided
        if "refresh_token" in tokens:
            credentials["refresh_token"] = tokens["refresh_token"]
            logger.trace("Desktop OAuth completed successfully with refresh token")
        else:
            logger.trace("Desktop OAuth completed successfully without refresh token")

        return credentials

    def _wait_for_callback(self) -> str | None:
        """Start a local HTTP server to receive the OAuth callback.

        Starts a local HTTP server on the configured port to handle the OAuth
        callback redirect. Validates the state parameter and extracts the
        authorization code from the callback parameters. Requests to any path
        other than ``/callback`` are answered with 404 and do not end the
        wait; the total wait is bounded by the callback timeout.

        Returns:
            The authorization code from the OAuth callback, or None if no
            callback arrived before the timeout.

        Raises:
            ValueError: If state mismatch occurs, OAuth error is returned,
                or the callback carries no authorization code.
            OSError: If the local server cannot bind to the configured port.
        """
        import time  # local import: only needed for the callback deadline

        authorization_code = None
        server_error = None
        callback_handled = False

        # Minimal page used when the packaged HTML file cannot be read, so the
        # browser still gets a valid response.
        fallback_html = (
            "<!DOCTYPE html><html><head><title>Authentication Successful</title></head>"
            "<body><h1>Authentication Successful!</h1>"
            "<p>You can now close this window and return to your application.</p>"
            "</body></html>"
        )

        def _get_callback_html() -> str:
            """Load the HTML callback page from file.

            Returns:
                The HTML callback page, or a minimal built-in page if the
                file cannot be read.
            """
            from pathlib import Path

            html_file = Path(__file__).parent / "desktop_callback.html"
            try:
                return html_file.read_text(encoding="utf-8")
            except OSError as e:
                logger.error(f"Could not load desktop_callback.html: {e}")
                return fallback_html

        class CallbackHandler(BaseHTTPRequestHandler):
            def do_GET(handler_self) -> None:  # noqa: N802, N805
                nonlocal authorization_code, server_error, callback_handled

                parsed_url = urlparse(handler_self.path)

                if parsed_url.path != "/callback":
                    # Unrelated request (e.g. favicon): answer and keep waiting
                    handler_self.send_error(404, "Not found")
                    return

                callback_handled = True
                params = parse_qs(parsed_url.query)

                # Verify state parameter
                received_state = params.get("state", [None])[0]
                if received_state != self._state:
                    server_error = "State mismatch - possible CSRF attack"
                    logger.error(server_error)
                    handler_self.send_error(400, server_error)
                    return

                # Check for errors
                if "error" in params:
                    error = params["error"][0]
                    error_desc = params.get("error_description", [""])[0]
                    server_error = f"OAuth error: {error} - {error_desc}"
                    logger.error(server_error)
                    # The error text comes from the query string. Keep it out
                    # of the status line (which is latin-1 encoded and must
                    # not contain CR/LF) and send it in the body instead.
                    handler_self.send_error(400, "OAuth error", explain=server_error)
                    return

                # Get authorization code
                authorization_code = params.get("code", [None])[0]

                if not authorization_code:
                    server_error = "No authorization code received"
                    logger.error(server_error)
                    handler_self.send_error(400, server_error)
                    return

                # Send success response
                handler_self.send_response(200)
                handler_self.send_header("Content-type", "text/html; charset=utf-8")
                handler_self.end_headers()

                success_html = _get_callback_html()
                handler_self.wfile.write(success_html.encode("utf-8"))

            def log_message(self, *args: Any) -> None:
                pass  # Suppress logging

        timeout = self._callback_timeout
        # A timeout of None means "wait indefinitely", matching server.timeout
        deadline = None if timeout is None else time.monotonic() + timeout

        # The context manager closes the listening socket even if handling a
        # request raises.
        with HTTPServer(("127.0.0.1", self.local_port), CallbackHandler) as server:
            while not callback_handled:
                if deadline is None:
                    server.timeout = None
                else:
                    remaining = deadline - time.monotonic()
                    if remaining <= 0:
                        break
                    server.timeout = remaining
                server.handle_request()

        if server_error:
            logger.error(f"OAuth callback server error: {server_error}")
            raise ValueError(server_error)

        if authorization_code is None:
            logger.warning("Timed out waiting for the OAuth callback")
            return None

        logger.trace("OAuth callback received successfully")
        return authorization_code

    def _exchange_code_for_tokens(self) -> dict:
        """Exchange authorization code for tokens using the desktop endpoint.

        Posts the authorization code, PKCE code verifier and redirect URI to
        ``{api_root}/auth/desktop/token``.

        Returns:
            The decoded JSON body of the token response.

        Raises:
            ValueError: If the endpoint does not answer with status 200 or
                the response body is not valid JSON.
            requests.RequestException: If the HTTP request itself fails.
        """
        token_url = f"{self.api_root}/auth/desktop/token"
        data = {
            "code": self._authorization_code,
            "code_verifier": self._code_verifier,
            "redirect_uri": self._redirect_uri,
        }

        response = requests.post(token_url, json=data, timeout=10)

        if response.status_code != 200:
            error_msg = f"Token exchange failed: {response.status_code}"
            try:
                error_data = response.json()
                error_msg += f" - {error_data.get('error', '')}: {error_data.get('error_description', '')}"
            except (json.JSONDecodeError, ValueError):
                # Error body is not JSON; include it verbatim
                error_msg += f" - {response.text}"
            logger.error(error_msg)
            raise ValueError(error_msg)

        logger.debug("Successfully exchanged authorization code for tokens")
        return response.json()

    @property
    def client_type(self) -> str:
        """Get the client type for desktop authentication strategy."""
        return "desktop"

    # refresh_token method is now inherited from AuthStrategy base class
hakai_api/auth/desktop_callback.html ADDED
@@ -0,0 +1,94 @@
1
<!DOCTYPE html>
<!-- Page shown in the browser after the desktop OAuth callback succeeds. -->
<html lang="en">
<head>
    <!-- Declare the encoding explicitly rather than relying on browser detection. -->
    <meta charset="utf-8">
    <!-- Without a viewport the max-width media queries below never match on phones. -->
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <title>Authentication Successful</title>
    <style>
        @import url('https://fonts.googleapis.com/css2?family=Nimbus+Sans+L:wght@300;700&display=swap');

        body {
            font-family: 'Nimbus Sans L', 'Nimbus Sans Light', -apple-system, system-ui, sans-serif;
            font-weight: 300;
            display: flex;
            justify-content: center;
            align-items: center;
            height: 100vh;
            margin: 0;
            background: linear-gradient(135deg, #82080B 0%, #B60C0F 100%);
            color: #4F4D4D;
        }

        .container {
            background: #F2F1ED;
            padding: 20px 50px 40px 50px;
            border-radius: 12px;
            box-shadow: 0 15px 35px rgba(130, 8, 11, 0.3);
            text-align: center;
            max-width: 400px;
            border: 1px solid #C9C7BE;
        }

        h1 {
            font-family: 'Nimbus Sans L', 'Nimbus Sans Bold', -apple-system, system-ui, sans-serif;
            font-weight: 700;
            color: #4F4D4D;
            font-size: clamp(2.2rem, 6vw, 2.5rem);
            margin-bottom: 16px;
            letter-spacing: -0.02em;
        }

        p {
            color: #4F4D4D;
            font-size: clamp(1.4rem, 5vw, 1.5rem);
            line-height: 1.4;
            margin: 0;
            font-weight: 400;
        }

        /* Mobile-first responsive adjustments */
        .container {
            margin: 20px;
        }

        @media (max-width: 480px) {
            .container {
                padding: 30px 20px;
                margin: 16px;
                max-width: calc(100vw - 32px);
            }

            h1 {
                margin-bottom: 20px;
            }
        }

        @media (min-width: 481px) and (max-width: 768px) {
            .container {
                padding: 35px 30px;
            }
        }

        /* Subtle animation for better UX */
        .container {
            animation: fadeInUp 0.6s ease-out;
        }

        @keyframes fadeInUp {
            from {
                opacity: 0;
                transform: translateY(20px);
            }
            to {
                opacity: 1;
                transform: translateY(0);
            }
        }
    </style>
</head>
<body>
    <div class="container">
        <h1>Authentication Successful!</h1>
        <p>You can now close this window and return to your application.</p>
    </div>
    <!-- Try to close the tab after two seconds; browsers may ignore this for tabs not opened by script. -->
    <script>setTimeout(() => window.close(), 2000);</script>
</body>
</html>