pakt 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pakt/trakt.py ADDED
@@ -0,0 +1,575 @@
1
+ """Trakt API client with batch operations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import time
7
+ from dataclasses import dataclass
8
+ from enum import Enum
9
+ from typing import Any, Callable
10
+
11
+ import httpx
12
+ from rich.console import Console
13
+
14
+ from pakt.config import TraktConfig
15
+ from pakt.models import RatedItem, TraktIds, WatchedItem
16
+
17
# Shared Rich console for all status/warning output produced by this module.
console = Console()

# Base URL for Trakt REST API requests (sync, search, OAuth endpoints).
TRAKT_API_URL = "https://api.trakt.tv"
# Base URL of the Trakt website.
TRAKT_AUTH_URL = "https://trakt.tv"

# Seconds before expiry at which the access token is proactively refreshed:
# one day (86 400 s). The original note says tokens are valid for 7 days.
TOKEN_REFRESH_THRESHOLD = 86_400
24
+
25
+
26
class TraktRateLimitError(Exception):
    """Raised when requests to Trakt remain rate limited (HTTP 429).

    Attributes:
        retry_after: Seconds the caller should wait before trying again.
    """

    def __init__(self, retry_after: int):
        text = f"Rate limited, retry after {retry_after}s"
        super().__init__(text)
        self.retry_after = retry_after
32
+
33
+
34
class TraktAccountLimitError(Exception):
    """Raised when Trakt rejects a write because an account limit was hit (HTTP 420).

    Attributes:
        limit: The item limit that was exceeded.
        is_vip: Whether the account already has VIP status.
        upgrade_url: Where a non-VIP user can upgrade; appended to the
            message only for non-VIP accounts.
    """

    def __init__(self, limit: int, is_vip: bool, upgrade_url: str = "https://trakt.tv/vip"):
        self.limit = limit
        self.is_vip = is_vip
        self.upgrade_url = upgrade_url

        segments = [f"Account limit exceeded ({limit} items)"]
        if not is_vip:
            # Upgrading only helps users who are not VIP yet.
            segments.append(f"Upgrade to VIP: {upgrade_url}")
        super().__init__(". ".join(segments))
45
+
46
+
47
@dataclass
class AccountLimits:
    """Account limits reported by Trakt's user settings, plus the VIP flag.

    Field order is part of the positional constructor signature.
    """

    is_vip: bool  # whether the account has VIP status
    collection_limit: int  # max items allowed in the collection
    watchlist_limit: int  # max items allowed in the watchlist
    list_limit: int  # max number of personal lists
    list_item_limit: int  # max items allowed per personal list
56
+
57
+
58
class DeviceAuthStatus(Enum):
    """Possible outcomes of polling Trakt's device-token endpoint."""

    SUCCESS = "success"  # a token was issued
    PENDING = "pending"  # the user has not acted yet
    INVALID_CODE = "invalid_code"  # unknown/used device code, or an unexpected reply
    EXPIRED = "expired"  # the device code or the polling window ran out
    DENIED = "denied"  # the user rejected the authorization request
    RATE_LIMITED = "rate_limited"  # polling faster than Trakt allows
66
+
67
+
68
@dataclass
class DeviceAuthResult:
    """Final result of a device-authentication polling session.

    Attributes:
        status: The state polling ended in.
        token: Raw token response from Trakt; None unless polling succeeded.
        message: Human-readable explanation suitable for display.
    """

    status: DeviceAuthStatus
    token: dict[str, Any] | None = None
    message: str = ""
74
+
75
+
76
class TraktClient:
    """Async Trakt API client optimized for batch operations.

    Usage::

        async with TraktClient(config) as client:
            watched = await client.get_watched_movies()

    Entering the context opens one shared ``httpx.AsyncClient`` (base URL
    ``TRAKT_API_URL``, 30 s timeout) and refreshes the access token if it is
    close to expiry; leaving the context closes that client. The OAuth helpers
    (``device_code``, ``poll_device_token``, ``refresh_access_token``,
    ``revoke_token``) open their own short-lived clients, so they also work
    outside the context.
    """

    def __init__(
        self,
        config: TraktConfig,
        on_token_refresh: Callable[[dict[str, Any]], None] | None = None,
    ):
        """Create a client (no network activity happens here).

        Args:
            config: Trakt settings. ``client_id``, ``client_secret``,
                ``access_token``, ``refresh_token`` and ``expires_at`` are read;
                the three token fields are overwritten after a successful refresh.
            on_token_refresh: Optional callback invoked with the raw token
                response after a refresh so the caller can persist it.
        """
        self.config = config
        self._client: httpx.AsyncClient | None = None
        self._on_token_refresh = on_token_refresh

    async def __aenter__(self) -> TraktClient:
        """Open the shared HTTP client and refresh the token if it is near expiry."""
        self._client = httpx.AsyncClient(
            base_url=TRAKT_API_URL,
            timeout=30.0,
            headers=self._headers,
        )
        try:
            # Refresh up front so every request in this session uses a valid token.
            await self._ensure_valid_token()
        except BaseException:
            # __aexit__ is not invoked when __aenter__ raises (e.g. the task is
            # cancelled mid-refresh), so release the client here.
            await self.__aexit__(None, None, None)
            raise
        return self

    async def __aexit__(self, *args) -> None:
        """Close the shared HTTP client; exceptions are never suppressed."""
        if self._client is not None:
            # Clear the reference first so later calls fail with the clear
            # "Client not initialized" error instead of using a closed client.
            client, self._client = self._client, None
            await client.aclose()

    def _token_needs_refresh(self) -> bool:
        """Return True when the access token is within the refresh threshold of expiry.

        Returns False when there is no access/refresh token pair or no known
        expiry time, since nothing can (or needs to) be refreshed then.
        """
        cfg = self.config
        if not (cfg.access_token and cfg.refresh_token and cfg.expires_at):
            return False
        return time.time() >= cfg.expires_at - TOKEN_REFRESH_THRESHOLD

    async def _ensure_valid_token(self) -> None:
        """Refresh the access token if it is about to expire (best effort).

        On success the new tokens are written to ``self.config``, the shared
        client's Authorization header is updated, and ``on_token_refresh`` (if
        any) is called with the raw token response. Failures are printed to the
        console instead of raised, so the session can continue with the
        existing token.
        """
        if not self._token_needs_refresh():
            return

        console.print("[cyan]Refreshing Trakt access token...[/]")
        try:
            token = await self.refresh_access_token()
            # Read every field before mutating anything so a malformed response
            # cannot leave config half-updated (new access token, stale expiry)
            # or the Authorization header out of sync with config.
            access_token = token["access_token"]
            refresh_token = token["refresh_token"]
            expires_at = token["created_at"] + token["expires_in"]

            self.config.access_token = access_token
            self.config.refresh_token = refresh_token
            self.config.expires_at = expires_at
            if self._client is not None:
                self._client.headers["Authorization"] = f"Bearer {access_token}"
        except Exception as e:
            console.print(f"[yellow]Token refresh failed: {e}[/]")
            return

        if self._on_token_refresh:
            try:
                self._on_token_refresh(token)
            except Exception as e:
                # The refresh itself succeeded; only persisting it failed, which
                # the user needs to know about (the old refresh token may no
                # longer be usable on the next run).
                console.print(f"[yellow]Token refreshed but could not be saved: {e}[/]")
                return

        console.print("[green]Token refreshed successfully[/]")

    @property
    def _headers(self) -> dict[str, str]:
        """Default headers for API requests; includes Authorization when a token is set."""
        headers = {
            "Content-Type": "application/json",
            "trakt-api-version": "2",
            "trakt-api-key": self.config.client_id,
        }
        if self.config.access_token:
            headers["Authorization"] = f"Bearer {self.config.access_token}"
        return headers

    @staticmethod
    def _int_header(headers: Any, name: str, default: int) -> int:
        """Return header ``name`` parsed as an int, or ``default`` if absent/unparsable.

        ``headers`` is any mapping with a ``.get`` method (e.g. ``httpx.Headers``).
        """
        raw = headers.get(name)
        if raw is None:
            return default
        try:
            return int(raw)
        except ValueError:
            return default

    async def _request(
        self,
        method: str,
        path: str,
        retries: int = 3,
        **kwargs,
    ) -> httpx.Response:
        """Send a request on the shared client, retrying while rate limited.

        Args:
            method: HTTP method, e.g. "GET" or "POST".
            path: API path relative to ``TRAKT_API_URL``.
            retries: Maximum number of attempts while Trakt answers HTTP 429.
            **kwargs: Forwarded to ``httpx.AsyncClient.request`` (``json``, ``params``, ...).

        Returns:
            The response, after ``raise_for_status`` has accepted it.

        Raises:
            RuntimeError: If called outside ``async with``.
            TraktAccountLimitError: On HTTP 420 (account limit exceeded).
            TraktRateLimitError: If every attempt was rate limited; carries the
                last Retry-After value seen (60 if none).
            httpx.HTTPStatusError: For any other error status.
        """
        if self._client is None:
            raise RuntimeError("Client not initialized. Use async with.")

        retry_after = 60
        for attempt in range(1, retries + 1):
            response = await self._client.request(method, path, **kwargs)

            if response.status_code == 429:
                retry_after = self._int_header(response.headers, "Retry-After", 60)
                if attempt == retries:
                    # No attempts left: sleeping first would only delay the error.
                    break
                console.print(
                    f"[yellow]Rate limited, waiting {retry_after}s "
                    f"(attempt {attempt}/{retries})[/]"
                )
                # One extra second of slack past the server's hint.
                await asyncio.sleep(retry_after + 1)
                continue

            # Trakt signals an exceeded (non-VIP) account limit with HTTP 420.
            if response.status_code == 420:
                headers = response.headers
                is_vip = headers.get("X-VIP-User", "false").lower() == "true"
                limit = self._int_header(headers, "X-Account-Limit", 100)
                upgrade_url = headers.get("X-Upgrade-URL", "https://trakt.tv/vip")
                raise TraktAccountLimitError(limit, is_vip, upgrade_url)

            response.raise_for_status()
            return response

        raise TraktRateLimitError(retry_after)

    async def _get_json(self, path: str) -> Any:
        """GET ``path`` via :meth:`_request` and return the decoded JSON body."""
        response = await self._request("GET", path)
        return response.json()

    async def _sync_write(
        self,
        path: str,
        empty_result: dict[str, Any],
        **groups: list[dict] | None,
    ) -> dict[str, Any]:
        """POST the non-empty item groups to a /sync endpoint.

        Args:
            path: Endpoint path, e.g. "/sync/history".
            empty_result: Returned unchanged, without any request, when every
                group is None or empty.
            **groups: Item lists keyed by payload name (movies/shows/episodes);
                empty or None groups are left out of the payload.

        Returns:
            The decoded JSON response, or ``empty_result``.
        """
        payload = {name: items for name, items in groups.items() if items}
        if not payload:
            return empty_result
        response = await self._request("POST", path, json=payload)
        return response.json()

    # =========================================================================
    # BATCH READ OPERATIONS - Single call gets everything
    # =========================================================================

    async def get_watched_movies(self) -> list[WatchedItem]:
        """Get ALL watched movies in a single API call."""
        return [WatchedItem(**item) for item in await self._get_json("/sync/watched/movies")]

    async def get_watched_shows(self) -> list[WatchedItem]:
        """Get ALL watched shows in a single API call."""
        return [WatchedItem(**item) for item in await self._get_json("/sync/watched/shows")]

    async def get_movie_ratings(self) -> list[RatedItem]:
        """Get ALL movie ratings in a single API call."""
        return [RatedItem(**item) for item in await self._get_json("/sync/ratings/movies")]

    async def get_show_ratings(self) -> list[RatedItem]:
        """Get ALL show ratings in a single API call."""
        return [RatedItem(**item) for item in await self._get_json("/sync/ratings/shows")]

    async def get_episode_ratings(self) -> list[RatedItem]:
        """Get ALL episode ratings in a single API call."""
        return [RatedItem(**item) for item in await self._get_json("/sync/ratings/episodes")]

    async def get_collection_movies(self) -> list[dict[str, Any]]:
        """Get ALL collected movies in a single API call (raw JSON items)."""
        return await self._get_json("/sync/collection/movies")

    async def get_collection_shows(self) -> list[dict[str, Any]]:
        """Get ALL collected shows in a single API call (raw JSON items)."""
        return await self._get_json("/sync/collection/shows")

    async def get_watchlist_movies(self) -> list[dict[str, Any]]:
        """Get ALL watchlist movies in a single API call (raw JSON items)."""
        return await self._get_json("/sync/watchlist/movies")

    async def get_watchlist_shows(self) -> list[dict[str, Any]]:
        """Get ALL watchlist shows in a single API call (raw JSON items)."""
        return await self._get_json("/sync/watchlist/shows")

    async def get_user_settings(self) -> dict[str, Any]:
        """Get user settings including VIP status and account limits."""
        return await self._get_json("/users/settings")

    async def get_account_limits(self) -> AccountLimits:
        """Get account limits for the authenticated user.

        Missing or null sections/values fall back to the non-VIP defaults used
        here: 100 items for collection, watchlist and per-list, and 2 lists.

        Returns:
            AccountLimits with VIP status and the various limits.
        """
        settings = await self.get_user_settings()
        user = settings.get("user") or {}
        limits = settings.get("limits") or {}

        def limit(section: str, key: str, default: int) -> int:
            # Tolerate an absent or null section/value rather than crashing.
            value = (limits.get(section) or {}).get(key)
            return default if value is None else value

        return AccountLimits(
            is_vip=bool(user.get("vip", False)),
            collection_limit=limit("collection", "item_count", 100),
            watchlist_limit=limit("watchlist", "item_count", 100),
            list_limit=limit("list", "count", 2),
            list_item_limit=limit("list", "item_count", 100),
        )

    # =========================================================================
    # BATCH WRITE OPERATIONS - Single call updates everything
    # =========================================================================

    async def add_to_history(
        self,
        movies: list[dict] | None = None,
        shows: list[dict] | None = None,
        episodes: list[dict] | None = None,
    ) -> dict[str, Any]:
        """Add multiple items to watch history in a single call."""
        return await self._sync_write(
            "/sync/history",
            {"added": {"movies": 0, "episodes": 0}},
            movies=movies,
            shows=shows,
            episodes=episodes,
        )

    async def remove_from_history(
        self,
        movies: list[dict] | None = None,
        shows: list[dict] | None = None,
        episodes: list[dict] | None = None,
    ) -> dict[str, Any]:
        """Remove multiple items from watch history in a single call."""
        return await self._sync_write(
            "/sync/history/remove",
            {"deleted": {"movies": 0, "episodes": 0}},
            movies=movies,
            shows=shows,
            episodes=episodes,
        )

    async def add_ratings(
        self,
        movies: list[dict] | None = None,
        shows: list[dict] | None = None,
        episodes: list[dict] | None = None,
    ) -> dict[str, Any]:
        """Add/update multiple ratings in a single call."""
        return await self._sync_write(
            "/sync/ratings",
            {"added": {"movies": 0, "shows": 0, "episodes": 0}},
            movies=movies,
            shows=shows,
            episodes=episodes,
        )

    async def remove_ratings(
        self,
        movies: list[dict] | None = None,
        shows: list[dict] | None = None,
        episodes: list[dict] | None = None,
    ) -> dict[str, Any]:
        """Remove multiple ratings in a single call."""
        return await self._sync_write(
            "/sync/ratings/remove",
            {"deleted": {"movies": 0, "shows": 0, "episodes": 0}},
            movies=movies,
            shows=shows,
            episodes=episodes,
        )

    async def add_to_collection(
        self,
        movies: list[dict] | None = None,
        shows: list[dict] | None = None,
    ) -> dict[str, Any]:
        """Add items to collection with optional metadata."""
        return await self._sync_write(
            "/sync/collection",
            {"added": {"movies": 0, "shows": 0}},
            movies=movies,
            shows=shows,
        )

    async def remove_from_collection(
        self,
        movies: list[dict] | None = None,
        shows: list[dict] | None = None,
    ) -> dict[str, Any]:
        """Remove items from collection."""
        return await self._sync_write(
            "/sync/collection/remove",
            {"deleted": {"movies": 0, "shows": 0}},
            movies=movies,
            shows=shows,
        )

    async def add_to_watchlist(
        self,
        movies: list[dict] | None = None,
        shows: list[dict] | None = None,
    ) -> dict[str, Any]:
        """Add items to watchlist."""
        return await self._sync_write(
            "/sync/watchlist",
            {"added": {"movies": 0, "shows": 0}},
            movies=movies,
            shows=shows,
        )

    async def remove_from_watchlist(
        self,
        movies: list[dict] | None = None,
        shows: list[dict] | None = None,
    ) -> dict[str, Any]:
        """Remove items from watchlist."""
        return await self._sync_write(
            "/sync/watchlist/remove",
            {"deleted": {"movies": 0, "shows": 0}},
            movies=movies,
            shows=shows,
        )

    # =========================================================================
    # SEARCH - For ID lookups (cached heavily)
    # =========================================================================

    async def search_by_id(
        self,
        id_type: str,
        media_id: str,
        media_type: str | None = None,
    ) -> list[dict[str, Any]]:
        """Search for an item by external ID via ``/search/{id_type}/{media_id}``.

        Args:
            id_type: ID namespace placed in the path (e.g. "imdb", "tmdb").
            media_id: The ID value placed in the path.
            media_type: Optional ``type`` query filter (e.g. "movie", "show").

        Returns:
            The decoded JSON list of search results.
        """
        # id_type is already encoded in the path; only the optional type
        # filter belongs in the query string.
        params = {"type": media_type} if media_type else {}
        response = await self._request("GET", f"/search/{id_type}/{media_id}", params=params)
        return response.json()

    # =========================================================================
    # AUTHENTICATION
    # =========================================================================

    async def device_code(self) -> dict[str, Any]:
        """Start the device authentication flow.

        Returns:
            The decoded JSON from ``/oauth/device/code`` (per Trakt's docs this
            includes ``device_code``, ``user_code``, ``verification_url``,
            ``expires_in`` and ``interval``).

        Raises:
            httpx.HTTPStatusError: If Trakt rejects the request.
        """
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{TRAKT_API_URL}/oauth/device/code",
                json={"client_id": self.config.client_id},
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            return response.json()

    async def poll_device_token(
        self,
        device_code: str,
        interval: int = 5,
        expires_in: int = 600,
    ) -> DeviceAuthResult:
        """Poll for the device token until the user acts or time runs out.

        Args:
            device_code: The ``device_code`` obtained from :meth:`device_code`.
            interval: Seconds to wait between polls while authorization is pending.
            expires_in: Total seconds to keep polling before giving up.

        Returns:
            DeviceAuthResult with one of these statuses (pending responses
            are polled through, so PENDING is never returned):
            - SUCCESS: token obtained (``token`` holds the raw response)
            - INVALID_CODE: device code invalid (404), already used (409), or
              any other unexpected status code
            - EXPIRED: device code expired (410), or ``expires_in`` elapsed
            - DENIED: user explicitly denied authorization (418)

            HTTP 429 (polling too fast) is not returned as RATE_LIMITED; the
            polling interval is doubled (capped at 30 s) and polling continues.
        """
        # Monotonic clock: the elapsed-time check must not jump if the wall
        # clock is adjusted during the (up to several minutes long) wait.
        deadline = time.monotonic() + expires_in
        current_interval = interval

        async with httpx.AsyncClient() as client:
            while time.monotonic() < deadline:
                response = await client.post(
                    f"{TRAKT_API_URL}/oauth/device/token",
                    json={
                        "code": device_code,
                        "client_id": self.config.client_id,
                        "client_secret": self.config.client_secret,
                    },
                    headers={"Content-Type": "application/json"},
                )
                status = response.status_code

                if status == 200:
                    return DeviceAuthResult(
                        status=DeviceAuthStatus.SUCCESS,
                        token=response.json(),
                        message="Authentication successful",
                    )
                if status == 400:
                    # Pending authorization - keep polling.
                    await asyncio.sleep(current_interval)
                    continue
                if status == 404:
                    return DeviceAuthResult(
                        status=DeviceAuthStatus.INVALID_CODE,
                        message="Invalid device code. Please restart authentication.",
                    )
                if status == 409:
                    # Code was already exchanged for a token; it cannot be reused.
                    return DeviceAuthResult(
                        status=DeviceAuthStatus.INVALID_CODE,
                        message="Device code was already used.",
                    )
                if status == 410:
                    return DeviceAuthResult(
                        status=DeviceAuthStatus.EXPIRED,
                        message="Device code has expired. Please restart authentication.",
                    )
                if status == 418:
                    return DeviceAuthResult(
                        status=DeviceAuthStatus.DENIED,
                        message="User denied authorization.",
                    )
                if status == 429:
                    # Polling too fast: back off exponentially, capped at 30 s.
                    current_interval = min(current_interval * 2, 30)
                    await asyncio.sleep(current_interval)
                    continue
                return DeviceAuthResult(
                    status=DeviceAuthStatus.INVALID_CODE,
                    message=f"Unexpected error: {status}",
                )

        return DeviceAuthResult(
            status=DeviceAuthStatus.EXPIRED,
            message="Polling timed out. Please restart authentication.",
        )

    async def refresh_access_token(self) -> dict[str, Any]:
        """Exchange the stored refresh token for a new token set.

        Does not modify ``self.config``; :meth:`_ensure_valid_token` applies
        the result.

        Returns:
            The decoded JSON token response.

        Raises:
            httpx.HTTPStatusError: If Trakt rejects the refresh.
        """
        # NOTE(review): Trakt's docs show ``redirect_uri`` in the refresh body
        # as well; confirm whether it is required for this app's registration.
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{TRAKT_API_URL}/oauth/token",
                json={
                    "refresh_token": self.config.refresh_token,
                    "client_id": self.config.client_id,
                    "client_secret": self.config.client_secret,
                    "grant_type": "refresh_token",
                },
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            return response.json()

    async def revoke_token(self) -> bool:
        """Revoke the current access token (logout).

        Per Trakt API docs, should be called when user logs out to
        invalidate the token on Trakt's side.

        Returns:
            True if there was no token to revoke or Trakt answered 200,
            False for any other status.
        """
        if not self.config.access_token:
            return True

        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{TRAKT_API_URL}/oauth/revoke",
                json={
                    "token": self.config.access_token,
                    "client_id": self.config.client_id,
                    "client_secret": self.config.client_secret,
                },
                headers={"Content-Type": "application/json"},
            )
            return response.status_code == 200
564
+
565
+
566
def extract_trakt_ids(data: dict[str, Any]) -> TraktIds:
    """Build a TraktIds from the ``ids`` mapping of a Trakt API object.

    A missing ``ids`` key, or a missing individual ID inside it, yields None
    for the corresponding field.
    """
    id_map = data.get("ids", {})
    id_fields = ("trakt", "slug", "imdb", "tmdb", "tvdb")
    return TraktIds(**{field: id_map.get(field) for field in id_fields})