pdd-cli 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pdd-cli might be problematic. See the package's page on the registry for more details.

pdd/get_jwt_token.py ADDED
@@ -0,0 +1,290 @@
1
+ import asyncio
2
+ import time
3
+ from typing import Dict, Optional, Tuple
4
+
5
+ import keyring
6
+ import requests
7
+
8
# Exception types raised by the authentication helpers in this module.
# Only UserCancelledError and RateLimitError derive from AuthError;
# NetworkError and TokenError are independent Exception subclasses.


class AuthError(Exception):
    """Generic authentication failure; base for user-cancel and rate-limit errors."""


class NetworkError(Exception):
    """Signals that GitHub or Firebase could not be reached."""


class TokenError(Exception):
    """Signals a failed token exchange or token refresh."""


class UserCancelledError(AuthError):
    """Signals that the user denied or abandoned the authorization request."""


class RateLimitError(AuthError):
    """Signals that the remote service reported too many attempts."""
28
+
29
class DeviceFlow:
    """
    Handles the GitHub Device Flow (OAuth 2.0 device authorization grant).

    Call ``request_device_code`` to obtain a user code to show the user, then
    ``poll_for_token`` until GitHub issues an access token.
    """

    # RFC 8628 section 3.5: on ``slow_down`` the polling interval must grow by
    # 5 seconds, and the larger interval applies to all subsequent polls.
    SLOW_DOWN_INCREMENT = 5
    # Per-request HTTP timeout in seconds.
    REQUEST_TIMEOUT = 10

    def __init__(self, client_id: str):
        self.client_id = client_id
        self.device_code_url = "https://github.com/login/device/code"
        self.access_token_url = "https://github.com/login/oauth/access_token"
        self.scope = "repo,user"  # Adjust scopes as needed

    async def _post_json(self, url: str, form: Dict) -> Dict:
        """
        POST ``form`` to ``url`` and return the decoded JSON response body.

        ``requests`` is blocking, so the call runs in the event loop's default
        executor instead of stalling every other coroutine while it waits.

        Raises:
            requests.exceptions.RequestException: For connection failures and
                HTTP 4xx/5xx statuses (callers translate these).
            ValueError: If the body is not valid JSON.
        """
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: requests.post(
                url,
                headers={"Accept": "application/json"},
                data=form,
                timeout=self.REQUEST_TIMEOUT,
            ),
        )
        response.raise_for_status()  # Raise HTTPError for 4xx/5xx responses
        return response.json()

    async def request_device_code(self) -> Dict:
        """
        Requests a device code from GitHub.

        Returns:
            Dict: GitHub's response; callers read ``device_code``,
            ``user_code``, ``verification_uri``, ``interval`` and ``expires_in``.

        Raises:
            NetworkError: If GitHub cannot be reached.
            AuthError: If the request fails or GitHub returns an error payload.
        """
        try:
            data = await self._post_json(
                self.device_code_url,
                {"client_id": self.client_id, "scope": self.scope},
            )
        except requests.exceptions.ConnectionError as e:
            raise NetworkError(f"Failed to connect to GitHub: {e}") from e
        except (requests.exceptions.RequestException, ValueError) as e:
            raise AuthError(f"Error requesting device code: {e}") from e

        # Surface an ``error`` field in a 2xx body as an AuthError rather than
        # letting callers fail later with a KeyError on ``device_code``.
        if "error" in data:
            detail = data.get("error_description") or data["error"]
            raise AuthError(f"Error requesting device code: {detail}")
        return data

    def _next_poll_interval(self, data: Dict, interval: int) -> int:
        """
        Interpret an ``error`` payload from the token endpoint.

        Args:
            data: Decoded token-endpoint response that contains an ``error`` key.
            interval: The polling interval currently in use, in seconds.

        Returns:
            int: The interval to wait before the next poll (and for all later
            polls).

        Raises:
            AuthError: If the device code expired or GitHub reports another
                unrecoverable error.
            UserCancelledError: If the user denied access.
        """
        error = data["error"]
        if error == "authorization_pending":
            return interval
        if error == "slow_down":
            # Never poll faster than RFC 8628 allows, even if GitHub omits or
            # under-reports the new interval.
            return max(data.get("interval", 0), interval + self.SLOW_DOWN_INCREMENT)
        if error == "expired_token":
            raise AuthError("Device code expired.")
        if error == "access_denied":
            raise UserCancelledError("User denied access.")
        raise AuthError(f"GitHub authentication error: {error}")

    async def poll_for_token(self, device_code: str, interval: int, expires_in: int) -> str:
        """
        Polls GitHub for the access token until the user authenticates or the code expires.

        Args:
            device_code: The device code obtained from request_device_code.
            interval: The initial polling interval in seconds.
            expires_in: The time in seconds until the device code expires.

        Returns:
            str: The GitHub access token.

        Raises:
            NetworkError: If GitHub cannot be reached.
            AuthError: If the code expires, polling times out, or GitHub
                reports an unrecoverable error.
            UserCancelledError: If the user denies access.
            TokenError: If the token request fails or the response has no token.
        """
        # Monotonic clock: unaffected by wall-clock changes during the wait.
        deadline = time.monotonic() + expires_in
        form = {
            "client_id": self.client_id,
            "device_code": device_code,
            "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
        }
        while time.monotonic() < deadline:
            try:
                data = await self._post_json(self.access_token_url, form)
            except requests.exceptions.ConnectionError as e:
                raise NetworkError(f"Failed to connect to GitHub: {e}") from e
            except (requests.exceptions.RequestException, ValueError) as e:
                raise TokenError(f"Error exchanging device code for token: {e}") from e

            if "error" not in data:
                access_token = data.get("access_token")
                if not access_token:
                    raise TokenError("GitHub response did not include an access token.")
                return access_token

            # Raises on terminal errors; otherwise yields the (possibly
            # increased) interval that stays in effect for later polls.
            interval = self._next_poll_interval(data, interval)
            await asyncio.sleep(interval)

        raise AuthError("Authentication timed out.")
117
+
118
class FirebaseAuthenticator:
    """
    Handles Firebase authentication and token management.

    The Firebase refresh token is persisted in the system keyring under the
    service ``firebase-auth-<app_name>`` and the user name ``refresh_token``.
    """

    # Per-request HTTP timeout in seconds.
    REQUEST_TIMEOUT = 10
    # Firebase error codes meaning the stored refresh token can never work
    # again; it is deleted and the user has to sign in anew.
    _DEAD_REFRESH_TOKEN_CODES = ("INVALID_REFRESH_TOKEN", "TOKEN_EXPIRED")

    def __init__(self, firebase_api_key: str, app_name: str):
        self.firebase_api_key = firebase_api_key
        self.app_name = app_name
        self.keyring_service_name = f"firebase-auth-{app_name}"
        self.keyring_user_name = "refresh_token"

    def _store_refresh_token(self, refresh_token: str):
        """Stores the Firebase refresh token in the system keyring."""
        keyring.set_password(self.keyring_service_name, self.keyring_user_name, refresh_token)

    def _get_stored_refresh_token(self) -> Optional[str]:
        """Retrieves the Firebase refresh token from the system keyring (None if absent)."""
        return keyring.get_password(self.keyring_service_name, self.keyring_user_name)

    def _delete_stored_refresh_token(self):
        """Deletes the stored refresh token; missing-keyring and delete failures are printed, not raised."""
        try:
            keyring.delete_password(self.keyring_service_name, self.keyring_user_name)
        except keyring.errors.NoKeyringError:
            print("No keyring found. Token deletion skipped.")
        except keyring.errors.PasswordDeleteError:
            print("Failed to delete token from keyring.")

    async def _post(self, url: str, **kwargs):
        """
        Run a blocking ``requests.post`` in the event loop's default executor.

        Keeps the event loop responsive while waiting on Firebase; keyword
        arguments are forwarded to ``requests.post``.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, lambda: requests.post(url, timeout=self.REQUEST_TIMEOUT, **kwargs)
        )

    @staticmethod
    def _error_code(payload) -> str:
        """
        Extract the Firebase error code from a decoded error response body.

        Firebase REST errors look like ``{"error": {"message": "CODE"}}``, and
        the message may carry detail after the code (``"CODE : details"``), so
        only the leading code is returned. Returns ``""`` when none is present.
        """
        error = payload.get("error") if isinstance(payload, dict) else None
        message = error.get("message") if isinstance(error, dict) else None
        if not isinstance(message, str):
            return ""
        return message.split(":", 1)[0].strip()

    def _classify_refresh_failure(self, exc) -> Exception:
        """
        Map a failed refresh request to the exception that should be raised.

        Deletes the stored refresh token when Firebase reports it as dead.
        """
        response = getattr(exc, "response", None)
        # A requests.Response is falsy for 4xx/5xx statuses (bool() reflects
        # ``response.ok``), so it must be compared against None explicitly.
        if response is None or response.status_code != 400:
            return TokenError(f"Error refreshing Firebase token: {exc}")
        try:
            code = self._error_code(response.json())
        except ValueError:  # error body was not JSON
            code = ""
        if code in self._DEAD_REFRESH_TOKEN_CODES:
            self._delete_stored_refresh_token()
            return TokenError("Invalid or expired refresh token. Please re-authenticate.")
        if code == "TOO_MANY_ATTEMPTS_TRY_LATER":
            return RateLimitError("Too many refresh attempts. Please try again later.")
        return TokenError(f"Error refreshing Firebase token: {exc}")

    async def _refresh_firebase_token(self, refresh_token: str) -> str:
        """
        Refreshes the Firebase ID token using the refresh token.

        On success, the refresh token returned by Firebase is written to the
        keyring, replacing the stored one.

        Args:
            refresh_token: The Firebase refresh token.

        Returns:
            str: The new Firebase ID token.

        Raises:
            NetworkError: If Firebase cannot be reached.
            RateLimitError: If Firebase reports TOO_MANY_ATTEMPTS_TRY_LATER.
            TokenError: If the refresh token is invalid/expired (it is then
                deleted from the keyring) or the refresh fails otherwise.
        """
        try:
            response = await self._post(
                f"https://securetoken.googleapis.com/v1/token?key={self.firebase_api_key}",
                data={
                    "grant_type": "refresh_token",
                    "refresh_token": refresh_token,
                },
            )
            response.raise_for_status()
            data = response.json()
        except requests.exceptions.ConnectionError as e:
            raise NetworkError(f"Failed to connect to Firebase: {e}") from e
        except requests.exceptions.RequestException as e:
            raise self._classify_refresh_failure(e) from e
        except ValueError as e:  # success status but undecodable body
            raise TokenError(f"Error refreshing Firebase token: {e}") from e

        try:
            id_token = data["id_token"]
            new_refresh_token = data["refresh_token"]
        except (KeyError, TypeError) as e:
            raise TokenError(f"Error refreshing Firebase token: missing {e} in response") from e
        self._store_refresh_token(new_refresh_token)
        return id_token

    async def exchange_github_token_for_firebase_token(self, github_token: str) -> Tuple[str, str]:
        """
        Exchanges a GitHub access token for a Firebase ID token and refresh token.

        Args:
            github_token: The GitHub access token.

        Returns:
            Tuple[str, str]: The Firebase ID token and refresh token.

        Raises:
            NetworkError: If Firebase cannot be reached.
            TokenError: If the token exchange fails or the response is malformed.
        """
        try:
            response = await self._post(
                f"https://identitytoolkit.googleapis.com/v1/accounts:signInWithIdp?key={self.firebase_api_key}",
                # The Identity Toolkit REST API documents a JSON request body.
                json={
                    "requestUri": "http://localhost",  # Required by Firebase, but not used in Device Flow
                    "returnSecureToken": True,
                    "postBody": f"access_token={github_token}&providerId=github.com",
                },
            )
            response.raise_for_status()
            data = response.json()
        except requests.exceptions.ConnectionError as e:
            raise NetworkError(f"Failed to connect to Firebase: {e}") from e
        except (requests.exceptions.RequestException, ValueError) as e:
            raise TokenError(f"Error exchanging GitHub token for Firebase token: {e}") from e

        try:
            return data["idToken"], data["refreshToken"]
        except (KeyError, TypeError) as e:
            raise TokenError(
                f"Error exchanging GitHub token for Firebase token: missing {e} in response"
            ) from e

    def verify_firebase_token(self, id_token: str) -> bool:
        """
        Verifies the Firebase ID token.

        Note: This is a simplified verification that only checks if the token exists.
        For production use, implement proper token verification.
        """
        return bool(id_token)
229
+
230
async def get_jwt_token(firebase_api_key: str, github_client_id: str, app_name: str = "my-cli-app") -> str:
    """
    Get a Firebase ID token using GitHub's Device Flow authentication.

    A refresh token stored in the keyring by an earlier run is tried first.
    Only if that fails does the interactive GitHub Device Flow start. The
    refresh token obtained from a completed Device Flow is stored for next time.

    Args:
        firebase_api_key: Firebase Web API key
        github_client_id: OAuth client ID for GitHub app
        app_name: Unique name for your CLI application

    Returns:
        str: A valid Firebase ID token

    Raises:
        AuthError: If authentication fails (RateLimitError, a subclass, when
            Firebase is throttling refresh attempts)
        NetworkError: If there are connectivity issues
        TokenError: If token exchange fails
    """
    firebase_auth = FirebaseAuthenticator(firebase_api_key, app_name)

    # Fast path: reuse a refresh token persisted by an earlier run.
    refresh_token = firebase_auth._get_stored_refresh_token()
    if refresh_token:
        try:
            id_token = await firebase_auth._refresh_firebase_token(refresh_token)
        except RateLimitError as e:
            # The stored token is kept; the caller just has to wait and retry.
            print(f"Token refresh failed: {e}")
            raise
        except NetworkError as e:
            # A connectivity failure says nothing about the token itself, so
            # keep it for the next run instead of forcing a fresh login.
            print(f"Token refresh failed: {e}")
            print("Attempting re-authentication...")
        except TokenError as e:
            print(f"Token refresh failed: {e}")
            # The refresh may already have deleted a dead token. Deleting a
            # missing entry can surface as a misleading "Failed to delete" message.
            if firebase_auth._get_stored_refresh_token() is not None:
                firebase_auth._delete_stored_refresh_token()
            print("Attempting re-authentication...")
        else:
            if firebase_auth.verify_firebase_token(id_token):
                return id_token
            print("Refreshed token is invalid. Attempting re-authentication.")
            firebase_auth._delete_stored_refresh_token()

    # Initiate Device Flow
    device_flow = DeviceFlow(github_client_id)
    device_code_response = await device_flow.request_device_code()

    # Display instructions to the user
    print(f"To authenticate, visit: {device_code_response['verification_uri']}")
    print(f"Enter code: {device_code_response['user_code']}")
    print("Waiting for authentication...")

    # Poll for GitHub token
    github_token = await device_flow.poll_for_token(
        device_code_response["device_code"],
        device_code_response["interval"],
        device_code_response["expires_in"],
    )

    # Exchange GitHub token for Firebase token
    id_token, refresh_token = await firebase_auth.exchange_github_token_for_firebase_token(github_token)

    # Store refresh token so the next run can skip the Device Flow
    firebase_auth._store_refresh_token(refresh_token)

    return id_token
@@ -0,0 +1,136 @@
1
+ import os
2
+ import sys
3
+ import importlib.resources
4
+ from typing import Optional
5
+
6
+ import click
7
+ from rich import print as rprint
8
+
9
+ # ----------------------------------------------------------------------
10
+ # Dynamically determine PDD_PATH at runtime.
11
+ # ----------------------------------------------------------------------
12
def get_local_pdd_path() -> str:
    """
    Return the PDD_PATH directory.

    The ``PDD_PATH`` environment variable wins when it is set, even to an
    empty string. Otherwise the directory containing the installed ``pdd``
    package's ``cli.py`` is used and written back to ``PDD_PATH``, so later
    lookups in this process agree. If that also fails, an error is printed
    and the process exits with status 1.
    """
    env_path = os.environ.get("PDD_PATH")
    if env_path is not None:
        return env_path

    try:
        with importlib.resources.path("pdd", "cli.py") as p:
            fallback_path = str(p.parent)
    except (ImportError, FileNotFoundError):
        # ImportError covers a missing ``pdd`` package (ModuleNotFoundError).
        # FileNotFoundError covers a package lacking the ``cli.py`` resource.
        rprint(
            "[red]Error: Could not determine the path to the 'pdd' package. "
            "Please set the PDD_PATH environment variable manually.[/red]"
        )
        sys.exit(1)

    # Also set it back into the environment for consistency
    os.environ["PDD_PATH"] = fallback_path
    return fallback_path
33
+
34
+ # ----------------------------------------------------------------------
35
+ # Simplified shell RC path logic
36
+ # ----------------------------------------------------------------------
37
def get_shell_rc_path(shell: str) -> Optional[str]:
    """Map a shell name to its startup rc file under the home directory, or None if unknown."""
    rc_locations = {
        "bash": (".bashrc",),
        "zsh": (".zshrc",),
        "fish": (".config", "fish", "config.fish"),
    }
    parts = rc_locations.get(shell)
    if parts is None:
        return None
    return os.path.join(os.path.expanduser("~"), *parts)
47
+
48
+
49
def _shell_basename(raw: str) -> str:
    """Reduce a command name such as ``"/bin/bash"`` or ``"-zsh"`` (login shell) to ``"bash"``/``"zsh"``."""
    return os.path.basename(raw.strip()).lstrip("-")


def get_current_shell() -> Optional[str]:
    """
    Determine the name of the shell the user is running (e.g. ``"zsh"``).

    Outside pytest, the parent process name is read with ``ps``. It is only
    trusted when it is a recognised shell, because the CLI may be started
    from a non-shell parent (IDE, script runner, ``make``). In every other
    case the basename of ``$SHELL`` is returned, or ``""`` when it is unset.
    """
    import subprocess

    recognised_shells = frozenset({"bash", "zsh", "fish", "sh", "dash", "ksh", "tcsh", "csh"})

    # Under pytest the process probe is skipped (presumably to keep tests hermetic).
    if not os.environ.get("PYTEST_CURRENT_TEST"):
        try:
            result = subprocess.run(
                ["ps", "-p", str(os.getppid()), "-o", "comm="],
                capture_output=True,
                text=True,
                timeout=5,
            )
        except (subprocess.SubprocessError, OSError):
            result = None  # ps unavailable (e.g. Windows), failed, or timed out
        if result is not None and result.returncode == 0:
            parent = _shell_basename(result.stdout)
            if parent in recognised_shells:
                return parent

    # Fallback to SHELL env var if the parent process is not a known shell.
    return os.path.basename(os.environ.get("SHELL", ""))
79
+
80
+
81
def get_completion_script_extension(shell: str) -> str:
    """
    Return the file extension of the bundled completion script for *shell*.

    Bash scripts use ``sh``; any other shell name doubles as its own extension.
    """
    if shell == "bash":
        return "sh"
    return shell
89
+
90
+
91
def install_completion():
    """
    Install shell completion for the PDD CLI.

    Detects the user's shell and locates the bundled ``pdd_completion.<ext>``
    script under PDD_PATH. It then appends a line sourcing that script to the
    shell's rc file, unless the line is already present. The rc file and its
    parent directory are created if missing.

    Raises:
        click.Abort: If the shell is unsupported, the completion script does
            not exist, or the rc file cannot be read or written.
    """
    import shlex

    shell = get_current_shell()
    rc_file = get_shell_rc_path(shell)
    if not rc_file:
        rprint(f"[red]Unsupported shell: {shell}[/red]")
        raise click.Abort()

    ext = get_completion_script_extension(shell)

    # Dynamically look up the local path at runtime:
    local_pdd_path = get_local_pdd_path()
    completion_script_path = os.path.join(local_pdd_path, f"pdd_completion.{ext}")

    if not os.path.exists(completion_script_path):
        rprint(f"[red]Completion script not found: {completion_script_path}[/red]")
        raise click.Abort()

    # Quote the path so install locations containing spaces or other shell
    # metacharacters still work. shlex.quote returns ordinary paths unchanged,
    # so lines written by earlier installs are still detected below.
    source_command = f"source {shlex.quote(completion_script_path)}"

    try:
        # Ensure the RC file exists (create if missing).
        if not os.path.exists(rc_file):
            os.makedirs(os.path.dirname(rc_file), exist_ok=True)
            with open(rc_file, "w", encoding="utf-8") as cf:
                cf.write("")

        # Read only to detect an existing entry; tolerate non-UTF-8 bytes
        # instead of crashing on an rc file written in another encoding.
        with open(rc_file, "r", encoding="utf-8", errors="replace") as cf:
            content = cf.read()

        if source_command in content:
            rprint(f"[yellow]Shell completion already installed for {shell}.[/yellow]")
            return

        with open(rc_file, "a", encoding="utf-8") as rf:
            rf.write(f"\n# PDD CLI completion\n{source_command}\n")

        rprint(f"[green]Shell completion installed for {shell}.[/green]")
        rprint(f"Please restart your shell or run 'source {rc_file}' to enable completion.")
    except OSError as exc:
        rprint(f"[red]Failed to install shell completion: {exc}[/red]")
        raise click.Abort() from exc