marqetive-lib 0.1.14__tar.gz → 0.1.16__tar.gz

This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (40)
  1. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/PKG-INFO +1 -1
  2. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/pyproject.toml +1 -1
  3. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/core/base.py +2 -2
  4. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/platforms/instagram/client.py +40 -8
  5. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/platforms/linkedin/client.py +75 -11
  6. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/platforms/tiktok/client.py +3 -9
  7. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/platforms/tiktok/media.py +3 -1
  8. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/platforms/twitter/client.py +93 -3
  9. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/utils/file_handlers.py +151 -13
  10. marqetive_lib-0.1.16/src/marqetive/utils/helpers.py +50 -0
  11. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/utils/media.py +2 -2
  12. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/utils/oauth.py +1 -2
  13. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/utils/retry.py +1 -2
  14. marqetive_lib-0.1.14/src/marqetive/utils/helpers.py +0 -99
  15. marqetive_lib-0.1.14/src/marqetive/utils/token_validator.py +0 -240
  16. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/README.md +0 -0
  17. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/__init__.py +0 -0
  18. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/core/__init__.py +0 -0
  19. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/core/client.py +0 -0
  20. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/core/exceptions.py +0 -0
  21. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/core/models.py +0 -0
  22. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/factory.py +0 -0
  23. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/platforms/__init__.py +0 -0
  24. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/platforms/instagram/__init__.py +0 -0
  25. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/platforms/instagram/exceptions.py +0 -0
  26. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/platforms/instagram/media.py +0 -0
  27. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/platforms/instagram/models.py +0 -0
  28. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/platforms/linkedin/__init__.py +0 -0
  29. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/platforms/linkedin/exceptions.py +0 -0
  30. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/platforms/linkedin/media.py +0 -0
  31. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/platforms/linkedin/models.py +0 -0
  32. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/platforms/tiktok/__init__.py +0 -0
  33. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/platforms/tiktok/exceptions.py +0 -0
  34. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/platforms/tiktok/models.py +0 -0
  35. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/platforms/twitter/__init__.py +0 -0
  36. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/platforms/twitter/exceptions.py +0 -0
  37. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/platforms/twitter/media.py +0 -0
  38. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/platforms/twitter/models.py +0 -0
  39. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/py.typed +0 -0
  40. {marqetive_lib-0.1.14 → marqetive_lib-0.1.16}/src/marqetive/utils/__init__.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: marqetive-lib
3
- Version: 0.1.14
3
+ Version: 0.1.16
4
4
  Summary: Modern Python utilities for web APIs
5
5
  Keywords: api,utilities,web,http,marqetive
6
6
  Requires-Python: >=3.12
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
4
4
 
5
5
  [project]
6
6
  name = "marqetive-lib"
7
- version = "0.1.14"
7
+ version = "0.1.16"
8
8
  description = "Modern Python utilities for web APIs"
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.12"
@@ -9,7 +9,7 @@ import inspect
9
9
  from abc import ABC, abstractmethod
10
10
  from collections.abc import Awaitable, Callable
11
11
  from datetime import datetime
12
- from traceback import TracebackException
12
+ from types import TracebackType
13
13
  from typing import Any
14
14
 
15
15
  from marqetive.core.client import APIClient
@@ -119,7 +119,7 @@ class SocialMediaPlatform(ABC):
119
119
  return self
120
120
 
121
121
  async def __aexit__(
122
- self, exc_type: type[Exception], exc_val: Any, exc_tb: TracebackException
122
+ self, exc_type: type[Exception] | None, exc_val: Exception | None, exc_tb: TracebackType | None
123
123
  ) -> None:
124
124
  """Async context manager exit."""
125
125
  if self.api_client:
@@ -360,7 +360,16 @@ class InstagramClient(SocialMediaPlatform):
360
360
  )
361
361
 
362
362
  result = await self._media_manager.publish_container(container_ids[0])
363
- return await self.get_post(result.media_id)
363
+
364
+ # Return minimal Post object without fetching details
365
+ return Post(
366
+ post_id=result.media_id,
367
+ platform=self.platform_name,
368
+ content=request.content,
369
+ status=PostStatus.PUBLISHED,
370
+ created_at=datetime.now(),
371
+ url=cast(HttpUrl, result.permalink) if result.permalink else None,
372
+ )
364
373
 
365
374
  async def _create_carousel_post(self, request: PostCreateRequest) -> Post:
366
375
  """Create a carousel post (2-10 images).
@@ -789,7 +798,9 @@ class InstagramClient(SocialMediaPlatform):
789
798
  alt_text = None
790
799
  if alt_texts and idx < len(alt_texts):
791
800
  alt_text = alt_texts[idx]
792
- media_items.append(MediaItem(url=validated_url, type="image", alt_text=alt_text))
801
+ media_items.append(
802
+ MediaItem(url=validated_url, type="image", alt_text=alt_text)
803
+ )
793
804
 
794
805
  # Create containers
795
806
  container_ids = await self._media_manager.create_feed_containers(
@@ -801,8 +812,15 @@ class InstagramClient(SocialMediaPlatform):
801
812
  # Publish
802
813
  result = await self._media_manager.publish_container(container_ids[0])
803
814
 
804
- # Fetch and return post
805
- return await self.get_post(result.media_id)
815
+ # Return minimal Post object without fetching details
816
+ return Post(
817
+ post_id=result.media_id,
818
+ platform=self.platform_name,
819
+ content=caption,
820
+ status=PostStatus.PUBLISHED,
821
+ created_at=datetime.now(),
822
+ url=cast(HttpUrl, result.permalink) if result.permalink else None,
823
+ )
806
824
 
807
825
  async def create_reel(
808
826
  self,
@@ -858,8 +876,15 @@ class InstagramClient(SocialMediaPlatform):
858
876
  # Publish
859
877
  result = await self._media_manager.publish_container(container_id)
860
878
 
861
- # Fetch and return post
862
- return await self.get_post(result.media_id)
879
+ # Return minimal Post object without fetching details
880
+ return Post(
881
+ post_id=result.media_id,
882
+ platform=self.platform_name,
883
+ content=caption,
884
+ status=PostStatus.PUBLISHED,
885
+ created_at=datetime.now(),
886
+ url=cast(HttpUrl, result.permalink) if result.permalink else None,
887
+ )
863
888
 
864
889
  async def create_story(
865
890
  self,
@@ -902,8 +927,15 @@ class InstagramClient(SocialMediaPlatform):
902
927
  # Publish
903
928
  result = await self._media_manager.publish_container(container_id)
904
929
 
905
- # Fetch and return post
906
- return await self.get_post(result.media_id)
930
+ # Return minimal Post object without fetching details
931
+ return Post(
932
+ post_id=result.media_id,
933
+ platform=self.platform_name,
934
+ content=None, # Stories don't have captions
935
+ status=PostStatus.PUBLISHED,
936
+ created_at=datetime.now(),
937
+ url=cast(HttpUrl, result.permalink) if result.permalink else None,
938
+ )
907
939
 
908
940
  # ==================== Helper Methods ====================
909
941
 
@@ -8,7 +8,8 @@ API Documentation: https://learn.microsoft.com/en-us/linkedin/marketing/communit
8
8
  """
9
9
 
10
10
  import contextlib
11
- from datetime import datetime
11
+ import os
12
+ from datetime import datetime, timedelta
12
13
  from typing import Any, cast
13
14
  from urllib.parse import quote
14
15
 
@@ -202,11 +203,15 @@ class LinkedInClient(SocialMediaPlatform):
202
203
  LinkedIn access tokens typically expire after 60 days. Use the
203
204
  refresh token to obtain a new access token.
204
205
 
206
+ Requires LINKEDIN_CLIENT_ID and LINKEDIN_CLIENT_SECRET environment
207
+ variables, or provide OAuth credentials via PlatformFactory.
208
+
205
209
  Returns:
206
210
  Updated credentials with new access token.
207
211
 
208
212
  Raises:
209
- PlatformAuthError: If token refresh fails.
213
+ PlatformAuthError: If token refresh fails or OAuth credentials
214
+ are missing.
210
215
  """
211
216
  if not self.credentials.refresh_token:
212
217
  raise PlatformAuthError(
@@ -214,13 +219,64 @@ class LinkedInClient(SocialMediaPlatform):
214
219
  platform=self.platform_name,
215
220
  )
216
221
 
217
- # Note: LinkedIn OAuth token refresh requires making a request to
218
- # https://www.linkedin.com/oauth/v2/accessToken
219
- # This is simplified for demonstration
220
- raise PlatformAuthError(
221
- "Token refresh not yet implemented. Please re-authenticate.",
222
- platform=self.platform_name,
223
- )
222
+ # Get OAuth credentials from environment
223
+ client_id = os.getenv("LINKEDIN_CLIENT_ID")
224
+ client_secret = os.getenv("LINKEDIN_CLIENT_SECRET")
225
+
226
+ if not client_id or not client_secret:
227
+ raise PlatformAuthError(
228
+ "LinkedIn OAuth credentials required for token refresh. "
229
+ "Set LINKEDIN_CLIENT_ID and LINKEDIN_CLIENT_SECRET environment "
230
+ "variables, or use PlatformFactory for automatic token refresh.",
231
+ platform=self.platform_name,
232
+ )
233
+
234
+ # Make token refresh request
235
+ token_url = "https://www.linkedin.com/oauth/v2/accessToken"
236
+ params = {
237
+ "grant_type": "refresh_token",
238
+ "refresh_token": self.credentials.refresh_token,
239
+ "client_id": client_id,
240
+ "client_secret": client_secret,
241
+ }
242
+
243
+ try:
244
+ async with httpx.AsyncClient() as client:
245
+ response = await client.post(
246
+ token_url,
247
+ data=params,
248
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
249
+ timeout=30.0,
250
+ )
251
+ response.raise_for_status()
252
+ token_data = response.json()
253
+
254
+ except httpx.HTTPStatusError as e:
255
+ raise PlatformAuthError(
256
+ f"Failed to refresh token: {e.response.text}",
257
+ platform=self.platform_name,
258
+ status_code=e.response.status_code,
259
+ ) from e
260
+
261
+ except httpx.HTTPError as e:
262
+ raise PlatformAuthError(
263
+ f"Network error refreshing token: {e}",
264
+ platform=self.platform_name,
265
+ ) from e
266
+
267
+ # Update credentials
268
+ self.credentials.access_token = token_data["access_token"]
269
+
270
+ # LinkedIn might provide new refresh token
271
+ if "refresh_token" in token_data:
272
+ self.credentials.refresh_token = token_data["refresh_token"]
273
+
274
+ # Calculate expiry
275
+ if "expires_in" in token_data:
276
+ expires_in = int(token_data["expires_in"])
277
+ self.credentials.expires_at = datetime.now() + timedelta(seconds=expires_in)
278
+
279
+ return self.credentials
224
280
 
225
281
  async def is_authenticated(self) -> bool:
226
282
  """Check if LinkedIn credentials are valid.
@@ -370,8 +426,16 @@ class LinkedInClient(SocialMediaPlatform):
370
426
  platform=self.platform_name,
371
427
  )
372
428
 
373
- # Fetch full post details
374
- return await self.get_post(post_id)
429
+ # Return minimal Post object without fetching details
430
+ return Post(
431
+ post_id=post_id,
432
+ platform=self.platform_name,
433
+ content=request.content or "",
434
+ status=PostStatus.PUBLISHED,
435
+ created_at=datetime.now(),
436
+ author_id=self.author_urn,
437
+ raw_data=response.data,
438
+ )
375
439
 
376
440
  except httpx.HTTPError as e:
377
441
  raise PlatformError(
@@ -231,24 +231,18 @@ class TikTokClient(SocialMediaPlatform):
231
231
  wait_for_publish=True,
232
232
  )
233
233
 
234
- # 4. Return post - either fetch full details or create minimal Post
235
- if upload_result.video_id:
236
- # Fetch the created post to return full Post object
237
- return await self.get_post(upload_result.video_id)
238
-
239
- # For private/SELF_ONLY posts, TikTok may not return video_id
240
- # Return a minimal Post object with publish_id
234
+ # 4. Return minimal Post object without fetching details
241
235
  return Post(
242
- post_id=upload_result.publish_id,
236
+ post_id=upload_result.video_id or upload_result.publish_id,
243
237
  platform=self.platform_name,
244
238
  content=request.content,
245
239
  status=PostStatus.PUBLISHED,
246
240
  created_at=datetime.now(),
247
241
  raw_data={
248
242
  "publish_id": upload_result.publish_id,
243
+ "video_id": upload_result.video_id,
249
244
  "upload_status": upload_result.status,
250
245
  "privacy_level": privacy_level.value,
251
- "note": "Video published but video_id not returned (common for private posts)",
252
246
  },
253
247
  )
254
248
 
@@ -718,7 +718,9 @@ class TikTokMediaManager:
718
718
  os.remove(temp_file_path)
719
719
  logger.debug(f"Cleaned up temp file: {temp_file_path}")
720
720
  except OSError as e:
721
- logger.warning(f"Failed to clean up temp file {temp_file_path}: {e}")
721
+ logger.warning(
722
+ f"Failed to clean up temp file {temp_file_path}: {e}"
723
+ )
722
724
 
723
725
  def _normalize_chunk_size(self, chunk_size: int, file_size: int) -> int:
724
726
  """Normalize chunk size to TikTok's requirements.
@@ -32,8 +32,10 @@ from marqetive.core.models import (
32
32
  PostCreateRequest,
33
33
  PostStatus,
34
34
  PostUpdateRequest,
35
+ ProgressStatus,
35
36
  )
36
37
  from marqetive.platforms.twitter.media import TwitterMediaManager
38
+ from marqetive.platforms.twitter.models import TwitterPostRequest
37
39
 
38
40
 
39
41
  class TwitterClient(SocialMediaPlatform):
@@ -226,6 +228,16 @@ class TwitterClient(SocialMediaPlatform):
226
228
  if media_ids:
227
229
  tweet_params["media_ids"] = media_ids
228
230
 
231
+ # Check for reply_to_post_id (used for threads and replies)
232
+ reply_to_id = getattr(request, "reply_to_post_id", None)
233
+ if reply_to_id:
234
+ tweet_params["in_reply_to_tweet_id"] = reply_to_id
235
+
236
+ # Check for quote_post_id (used for quote tweets)
237
+ quote_id = getattr(request, "quote_post_id", None)
238
+ if quote_id:
239
+ tweet_params["quote_tweet_id"] = quote_id
240
+
229
241
  response = self._tweepy_client.create_tweet(**tweet_params, user_auth=False)
230
242
  tweet_id = response.data["id"] # type: ignore[index]
231
243
 
@@ -530,6 +542,87 @@ class TwitterClient(SocialMediaPlatform):
530
542
  media_type=media_type,
531
543
  ) from e
532
544
 
545
+ # ==================== Retweet Methods ====================
546
+
547
+ async def retweet(self, tweet_id: str) -> bool:
548
+ """Retweet a tweet.
549
+
550
+ Args:
551
+ tweet_id: ID of the tweet to retweet.
552
+
553
+ Returns:
554
+ True if retweet was successful.
555
+
556
+ Raises:
557
+ PostNotFoundError: If tweet doesn't exist.
558
+ PlatformError: If retweet fails.
559
+ RuntimeError: If client not used as context manager.
560
+ """
561
+ if not self._tweepy_client:
562
+ raise RuntimeError("Client must be used as async context manager")
563
+
564
+ try:
565
+ response = self._tweepy_client.retweet(tweet_id, user_auth=False)
566
+ return response.data.get("retweeted", False) # type: ignore[union-attr]
567
+
568
+ except tweepy.errors.NotFound as e: # type: ignore[attr-defined]
569
+ raise PostNotFoundError(
570
+ post_id=tweet_id,
571
+ platform=self.platform_name,
572
+ status_code=404,
573
+ ) from e
574
+ except tweepy.TweepyException as e:
575
+ if "429" in str(e):
576
+ raise RateLimitError(
577
+ "Twitter rate limit exceeded",
578
+ platform=self.platform_name,
579
+ status_code=429,
580
+ ) from e
581
+ raise PlatformError(
582
+ f"Failed to retweet: {e}",
583
+ platform=self.platform_name,
584
+ ) from e
585
+
586
+ async def unretweet(self, tweet_id: str) -> bool:
587
+ """Undo a retweet (unretweet).
588
+
589
+ Args:
590
+ tweet_id: ID of the tweet to unretweet.
591
+
592
+ Returns:
593
+ True if unretweet was successful.
594
+
595
+ Raises:
596
+ PostNotFoundError: If tweet doesn't exist.
597
+ PlatformError: If unretweet fails.
598
+ RuntimeError: If client not used as context manager.
599
+ """
600
+ if not self._tweepy_client:
601
+ raise RuntimeError("Client must be used as async context manager")
602
+
603
+ try:
604
+ response = self._tweepy_client.unretweet(tweet_id, user_auth=False)
605
+ # Response indicates retweeted=False when successfully unretweeted
606
+ return not response.data.get("retweeted", True) # type: ignore[union-attr]
607
+
608
+ except tweepy.errors.NotFound as e: # type: ignore[attr-defined]
609
+ raise PostNotFoundError(
610
+ post_id=tweet_id,
611
+ platform=self.platform_name,
612
+ status_code=404,
613
+ ) from e
614
+ except tweepy.TweepyException as e:
615
+ if "429" in str(e):
616
+ raise RateLimitError(
617
+ "Twitter rate limit exceeded",
618
+ platform=self.platform_name,
619
+ status_code=429,
620
+ ) from e
621
+ raise PlatformError(
622
+ f"Failed to unretweet: {e}",
623
+ platform=self.platform_name,
624
+ ) from e
625
+
533
626
  # ==================== Thread Methods ====================
534
627
 
535
628
  async def create_thread(
@@ -567,9 +660,6 @@ class TwitterClient(SocialMediaPlatform):
567
660
  ... for post in thread_posts:
568
661
  ... print(f"Tweet {post.post_id}: {post.content}")
569
662
  """
570
- from marqetive.core.models import ProgressStatus
571
- from marqetive.platforms.twitter.models import TwitterPostRequest
572
-
573
663
  if not posts:
574
664
  raise ValidationError(
575
665
  "At least one tweet is required for thread creation",
@@ -17,7 +17,99 @@ from typing import Any
17
17
  import aiofiles
18
18
  import httpx
19
19
 
20
- from marqetive.utils.media import detect_mime_type, format_file_size
20
+ from marqetive.core.exceptions import ValidationError
21
+ from marqetive.utils.media import detect_mime_type, format_file_size, validate_media_url
22
+
23
+ # System directories that should never be written to
24
+ # Includes both standard paths and macOS /private/* equivalents
25
+ _BLOCKED_SYSTEM_DIRS: frozenset[str] = frozenset(
26
+ {
27
+ "/etc",
28
+ "/private/etc",
29
+ "/usr",
30
+ "/bin",
31
+ "/sbin",
32
+ "/var",
33
+ "/private/var",
34
+ "/root",
35
+ "/lib",
36
+ "/lib64",
37
+ "/boot",
38
+ }
39
+ )
40
+
41
+
42
+ def _validate_path(
43
+ file_path: str,
44
+ *,
45
+ allowed_base_dirs: set[str] | None = None,
46
+ ) -> str:
47
+ """Validate and normalize a file path for security.
48
+
49
+ Prevents path traversal attacks by:
50
+ 1. Checking for null bytes (injection attack)
51
+ 2. Resolving to absolute path (handles .. and symlinks)
52
+ 3. Blocking writes to sensitive system directories
53
+
54
+ Args:
55
+ file_path: Path to validate.
56
+ allowed_base_dirs: Optional set of allowed base directories.
57
+ If provided, path must be within one of these directories.
58
+
59
+ Returns:
60
+ Normalized absolute path.
61
+
62
+ Raises:
63
+ ValidationError: If path is invalid or blocked.
64
+
65
+ Example:
66
+ >>> _validate_path("/tmp/myfile.txt")
67
+ '/tmp/myfile.txt'
68
+ >>> _validate_path("../../../etc/passwd") # raises ValidationError
69
+ """
70
+ # Check for null bytes (path injection attack)
71
+ if "\x00" in file_path:
72
+ raise ValidationError(
73
+ "Path contains null bytes",
74
+ platform="file_handlers",
75
+ field="file_path",
76
+ )
77
+
78
+ # Resolve to absolute path (handles .., symlinks, etc.)
79
+ try:
80
+ resolved = Path(file_path).resolve()
81
+ resolved_str = str(resolved)
82
+ except (OSError, RuntimeError) as e:
83
+ raise ValidationError(
84
+ f"Invalid path: {e}",
85
+ platform="file_handlers",
86
+ field="file_path",
87
+ ) from e
88
+
89
+ # Block writes to system directories
90
+ for blocked in _BLOCKED_SYSTEM_DIRS:
91
+ if resolved_str.startswith(blocked + "/") or resolved_str == blocked:
92
+ raise ValidationError(
93
+ f"Writing to system directory '{blocked}' is not allowed",
94
+ platform="file_handlers",
95
+ field="file_path",
96
+ )
97
+
98
+ # If allowed_base_dirs specified, validate path is within them
99
+ if allowed_base_dirs:
100
+ is_allowed = any(
101
+ resolved_str.startswith(str(Path(base).resolve()) + "/")
102
+ or resolved_str == str(Path(base).resolve())
103
+ for base in allowed_base_dirs
104
+ )
105
+ if not is_allowed:
106
+ raise ValidationError(
107
+ f"Path '{file_path}' is outside allowed directories",
108
+ platform="file_handlers",
109
+ field="file_path",
110
+ )
111
+
112
+ return resolved_str
21
113
 
22
114
 
23
115
  class DownloadProgress:
@@ -68,8 +160,9 @@ async def download_file(
68
160
  chunk_size: int = 8192,
69
161
  progress_callback: Callable[[DownloadProgress], None] | None = None,
70
162
  timeout: float = 300.0,
163
+ validate_url: bool = True,
71
164
  ) -> str:
72
- """Download a file from URL asynchronously.
165
+ """Download a file from URL asynchronously with SSRF protection.
73
166
 
74
167
  Args:
75
168
  url: URL to download from.
@@ -77,11 +170,14 @@ async def download_file(
77
170
  chunk_size: Size of chunks to download (default: 8KB).
78
171
  progress_callback: Optional callback function called with progress updates.
79
172
  timeout: Request timeout in seconds (default: 5 minutes).
173
+ validate_url: If True (default), validate URL for SSRF protection.
174
+ Blocks private IPs, localhost, and other internal addresses.
80
175
 
81
176
  Returns:
82
177
  Path to the downloaded file.
83
178
 
84
179
  Raises:
180
+ ValidationError: If URL fails security validation.
85
181
  httpx.HTTPError: If download fails.
86
182
  IOError: If file write fails.
87
183
 
@@ -95,6 +191,14 @@ async def download_file(
95
191
  ... progress_callback=on_progress
96
192
  ... )
97
193
  """
194
+ # Validate URL for SSRF protection (blocks private IPs, localhost, etc.)
195
+ if validate_url:
196
+ url = validate_media_url(
197
+ url,
198
+ block_private_ips=True,
199
+ platform="file_handlers",
200
+ )
201
+
98
202
  # Create temp file if no destination specified
99
203
  if destination is None:
100
204
  temp_fd, destination = tempfile.mkstemp()
@@ -131,8 +235,9 @@ async def download_to_memory(
131
235
  *,
132
236
  max_size: int | None = None,
133
237
  timeout: float = 60.0,
238
+ validate_url: bool = True,
134
239
  ) -> bytes:
135
- """Download a file into memory.
240
+ """Download a file into memory with SSRF protection.
136
241
 
137
242
  Useful for small files that need to be processed immediately.
138
243
 
@@ -140,11 +245,14 @@ async def download_to_memory(
140
245
  url: URL to download from.
141
246
  max_size: Maximum allowed file size in bytes (raises ValueError if exceeded).
142
247
  timeout: Request timeout in seconds (default: 1 minute).
248
+ validate_url: If True (default), validate URL for SSRF protection.
249
+ Blocks private IPs, localhost, and other internal addresses.
143
250
 
144
251
  Returns:
145
252
  File content as bytes.
146
253
 
147
254
  Raises:
255
+ ValidationError: If URL fails security validation.
148
256
  httpx.HTTPError: If download fails.
149
257
  ValueError: If file exceeds max_size.
150
258
 
@@ -154,6 +262,14 @@ async def download_to_memory(
154
262
  ... max_size=1024 * 1024 # 1MB limit
155
263
  ... )
156
264
  """
265
+ # Validate URL for SSRF protection (blocks private IPs, localhost, etc.)
266
+ if validate_url:
267
+ url = validate_media_url(
268
+ url,
269
+ block_private_ips=True,
270
+ platform="file_handlers",
271
+ )
272
+
157
273
  async with httpx.AsyncClient(timeout=timeout) as client:
158
274
  response = await client.get(url)
159
275
  response.raise_for_status()
@@ -242,53 +358,75 @@ async def read_file_bytes(file_path: str) -> bytes:
242
358
  return await f.read()
243
359
 
244
360
 
245
- async def write_file_bytes(file_path: str, content: bytes) -> None:
246
- """Write bytes to file asynchronously.
361
+ async def write_file_bytes(
362
+ file_path: str,
363
+ content: bytes,
364
+ *,
365
+ allowed_base_dirs: set[str] | None = None,
366
+ ) -> None:
367
+ """Write bytes to file asynchronously with path validation.
247
368
 
248
369
  Args:
249
370
  file_path: Path where file should be written.
250
371
  content: Content to write.
372
+ allowed_base_dirs: Optional set of allowed base directories.
373
+ If provided, path must be within one of these directories.
251
374
 
252
375
  Raises:
376
+ ValidationError: If path is invalid or blocked.
253
377
  IOError: If write fails.
254
378
 
255
379
  Example:
256
- >>> await write_file_bytes('/path/to/output.bin', b'some data')
380
+ >>> await write_file_bytes('/tmp/output.bin', b'some data')
257
381
  """
382
+ # Validate path for security (prevents path traversal)
383
+ validated_path = _validate_path(file_path, allowed_base_dirs=allowed_base_dirs)
384
+
258
385
  # Ensure parent directory exists
259
- parent_dir = Path(file_path).parent
386
+ parent_dir = Path(validated_path).parent
260
387
  parent_dir.mkdir(parents=True, exist_ok=True)
261
388
 
262
- async with aiofiles.open(file_path, "wb") as f:
389
+ async with aiofiles.open(validated_path, "wb") as f:
263
390
  await f.write(content)
264
391
 
265
392
 
266
- async def copy_file_async(source: str, destination: str) -> None:
267
- """Copy file asynchronously.
393
+ async def copy_file_async(
394
+ source: str,
395
+ destination: str,
396
+ *,
397
+ allowed_base_dirs: set[str] | None = None,
398
+ ) -> None:
399
+ """Copy file asynchronously with path validation.
268
400
 
269
401
  Args:
270
402
  source: Source file path.
271
403
  destination: Destination file path.
404
+ allowed_base_dirs: Optional set of allowed base directories.
405
+ If provided, destination must be within one of these directories.
272
406
 
273
407
  Raises:
274
408
  FileNotFoundError: If source doesn't exist.
409
+ ValidationError: If destination path is invalid or blocked.
275
410
  IOError: If copy fails.
276
411
 
277
412
  Example:
278
- >>> await copy_file_async('/path/to/source.txt', '/path/to/dest.txt')
413
+ >>> await copy_file_async('/path/to/source.txt', '/tmp/dest.txt')
279
414
  """
280
415
  if not os.path.exists(source):
281
416
  raise FileNotFoundError(f"Source file not found: {source}")
282
417
 
418
+ # Validate destination path for security (prevents path traversal)
419
+ validated_dest = _validate_path(destination, allowed_base_dirs=allowed_base_dirs)
420
+
283
421
  # Ensure destination directory exists
284
- dest_dir = Path(destination).parent
422
+ dest_dir = Path(validated_dest).parent
285
423
  dest_dir.mkdir(parents=True, exist_ok=True)
286
424
 
287
425
  # Read from source and write to destination
288
426
  async with aiofiles.open(source, "rb") as src:
289
427
  content = await src.read()
290
428
 
291
- async with aiofiles.open(destination, "wb") as dest:
429
+ async with aiofiles.open(validated_dest, "wb") as dest:
292
430
  await dest.write(content)
293
431
 
294
432
 
@@ -0,0 +1,50 @@
1
+ """Helper functions for common API operations."""
2
+
3
+ import json
4
+ from typing import Any
5
+ from urllib.parse import parse_qs, urlparse
6
+
7
+
8
+ def format_response(
9
+ data: dict[str, Any], *, pretty: bool = False, indent: int = 2
10
+ ) -> str:
11
+ """Format API response data as a string.
12
+
13
+ Args:
14
+ data: The response data dictionary
15
+ pretty: Whether to format with indentation (default: False)
16
+ indent: Number of spaces for indentation if pretty=True (default: 2)
17
+
18
+ Returns:
19
+ Formatted string representation of the response
20
+
21
+ Example:
22
+ >>> data = {"user": "john", "status": "active"}
23
+ >>> print(format_response(data, pretty=True))
24
+ {
25
+ "user": "john",
26
+ "status": "active"
27
+ }
28
+ """
29
+ if pretty:
30
+ return json.dumps(data, indent=indent, sort_keys=True)
31
+ return json.dumps(data)
32
+
33
+
34
+ def parse_query_params(url: str) -> dict[str, Any]:
35
+ """Parse query parameters from a URL.
36
+
37
+ Args:
38
+ url: The URL string to parse
39
+
40
+ Returns:
41
+ Dictionary of query parameters
42
+
43
+ Example:
44
+ >>> url = "https://api.example.com/users?page=1&limit=10"
45
+ >>> params = parse_query_params(url)
46
+ >>> print(params)
47
+ {'page': ['1'], 'limit': ['10']}
48
+ """
49
+ parsed = urlparse(url)
50
+ return dict(parse_qs(parsed.query))
@@ -17,6 +17,8 @@ from pathlib import Path
17
17
  from typing import Literal
18
18
  from urllib.parse import urlparse
19
19
 
20
+ import aiofiles
21
+
20
22
  from marqetive.core.exceptions import ValidationError
21
23
 
22
24
  # Initialize mimetypes database
@@ -314,8 +316,6 @@ async def chunk_file(
314
316
  >>> async for chunk in chunk_file('large_video.mp4', chunk_size=5*1024*1024):
315
317
  ... await upload_chunk(chunk)
316
318
  """
317
- import aiofiles
318
-
319
319
  if not os.path.exists(file_path):
320
320
  raise FileNotFoundError(f"File not found: {file_path}")
321
321
 
@@ -4,6 +4,7 @@ This module provides utilities for refreshing OAuth2 access tokens across
4
4
  different social media platforms.
5
5
  """
6
6
 
7
+ import base64
7
8
  import logging
8
9
  import re
9
10
  from datetime import datetime, timedelta
@@ -147,8 +148,6 @@ async def refresh_twitter_token(
147
148
  ... os.getenv("TWITTER_CLIENT_SECRET")
148
149
  ... )
149
150
  """
150
- import base64
151
-
152
151
  if not credentials.refresh_token:
153
152
  raise PlatformAuthError(
154
153
  "No refresh token available",
@@ -7,6 +7,7 @@ with exponential backoff for async functions.
7
7
  import asyncio
8
8
  import functools
9
9
  import logging
10
+ import random
10
11
  from collections.abc import Awaitable, Callable
11
12
  from dataclasses import dataclass
12
13
  from typing import Any, TypeVar
@@ -51,8 +52,6 @@ class BackoffConfig:
51
52
  Returns:
52
53
  Delay in seconds.
53
54
  """
54
- import random
55
-
56
55
  delay = min(
57
56
  self.base_delay * (self.exponential_base**attempt),
58
57
  self.max_delay,
@@ -1,99 +0,0 @@
1
- """Helper functions for common API operations."""
2
-
3
- from typing import Any
4
- from urllib.parse import parse_qs, urlencode, urlparse
5
-
6
-
7
- def format_response(
8
- data: dict[str, Any], *, pretty: bool = False, indent: int = 2
9
- ) -> str:
10
- """Format API response data as a string.
11
-
12
- Args:
13
- data: The response data dictionary
14
- pretty: Whether to format with indentation (default: False)
15
- indent: Number of spaces for indentation if pretty=True (default: 2)
16
-
17
- Returns:
18
- Formatted string representation of the response
19
-
20
- Example:
21
- >>> data = {"user": "john", "status": "active"}
22
- >>> print(format_response(data, pretty=True))
23
- {
24
- "user": "john",
25
- "status": "active"
26
- }
27
- """
28
- import json
29
-
30
- if pretty:
31
- return json.dumps(data, indent=indent, sort_keys=True)
32
- return json.dumps(data)
33
-
34
-
35
- def parse_query_params(url: str) -> dict[str, Any]:
36
- """Parse query parameters from a URL.
37
-
38
- Args:
39
- url: The URL string to parse
40
-
41
- Returns:
42
- Dictionary of query parameters
43
-
44
- Example:
45
- >>> url = "https://api.example.com/users?page=1&limit=10"
46
- >>> params = parse_query_params(url)
47
- >>> print(params)
48
- {'page': ['1'], 'limit': ['10']}
49
- """
50
- parsed = urlparse(url)
51
- return dict(parse_qs(parsed.query))
52
-
53
-
54
- def build_query_string(params: dict[str, Any]) -> str:
55
- """Build a query string from a dictionary of parameters.
56
-
57
- Args:
58
- params: Dictionary of query parameters
59
-
60
- Returns:
61
- URL-encoded query string
62
-
63
- Example:
64
- >>> params = {"page": 1, "limit": 10, "sort": "name"}
65
- >>> query = build_query_string(params)
66
- >>> print(query)
67
- page=1&limit=10&sort=name
68
- """
69
- return urlencode(params)
70
-
71
-
72
- def merge_headers(
73
- default_headers: dict[str, str] | None = None,
74
- custom_headers: dict[str, str] | None = None,
75
- ) -> dict[str, str]:
76
- """Merge default and custom headers.
77
-
78
- Custom headers take precedence over default headers.
79
-
80
- Args:
81
- default_headers: Default headers dictionary
82
- custom_headers: Custom headers to merge
83
-
84
- Returns:
85
- Merged headers dictionary
86
-
87
- Example:
88
- >>> defaults = {"Content-Type": "application/json"}
89
- >>> custom = {"Authorization": "Bearer token"}
90
- >>> headers = merge_headers(defaults, custom)
91
- >>> print(headers)
92
- {'Content-Type': 'application/json', 'Authorization': 'Bearer token'}
93
- """
94
- result = {}
95
- if default_headers:
96
- result.update(default_headers)
97
- if custom_headers:
98
- result.update(custom_headers)
99
- return result
@@ -1,240 +0,0 @@
1
- """Token validation utilities for checking credential validity.
2
-
3
- This module provides utilities for validating OAuth tokens and determining
4
- if they need to be refreshed.
5
- """
6
-
7
- import re
8
- from datetime import datetime, timedelta
9
- from typing import Any
10
-
11
- from marqetive.core.models import AuthCredentials
12
-
13
-
14
- def is_token_expired(
15
- expires_at: datetime | None,
16
- threshold_minutes: int = 5,
17
- ) -> bool:
18
- """Check if a token has expired or will expire soon.
19
-
20
- Args:
21
- expires_at: Token expiration timestamp.
22
- threshold_minutes: Consider expired if expires within this many minutes.
23
-
24
- Returns:
25
- True if token is expired or will expire soon, False otherwise.
26
-
27
- Example:
28
- >>> from datetime import datetime, timedelta
29
- >>> expires = datetime.now() + timedelta(minutes=3)
30
- >>> is_token_expired(expires, threshold_minutes=5)
31
- True
32
- >>> expires = datetime.now() + timedelta(hours=1)
33
- >>> is_token_expired(expires, threshold_minutes=5)
34
- False
35
- """
36
- if expires_at is None:
37
- # No expiry means token doesn't expire
38
- return False
39
-
40
- threshold = datetime.now() + timedelta(minutes=threshold_minutes)
41
- return expires_at <= threshold
42
-
43
-
44
- def needs_refresh(
45
- credentials: AuthCredentials,
46
- threshold_minutes: int = 5, # noqa: ARG001
47
- ) -> bool:
48
- """Check if credentials need to be refreshed.
49
-
50
- Args:
51
- credentials: Credentials to check.
52
- threshold_minutes: Expiry threshold in minutes.
53
-
54
- Returns:
55
- True if refresh is needed, False otherwise.
56
-
57
- Example:
58
- >>> creds = AuthCredentials(
59
- ... platform="twitter",
60
- ... access_token="token",
61
- ... expires_at=datetime.now() + timedelta(minutes=2)
62
- ... )
63
- >>> needs_refresh(creds)
64
- True
65
- """
66
- return credentials.needs_refresh()
67
-
68
-
69
- def validate_token_format(token: str, min_length: int = 10) -> bool:
70
- """Validate basic token format.
71
-
72
- Checks if token looks valid (not empty, meets minimum length).
73
-
74
- Args:
75
- token: Token string to validate.
76
- min_length: Minimum acceptable token length.
77
-
78
- Returns:
79
- True if token format is valid, False otherwise.
80
-
81
- Example:
82
- >>> validate_token_format("abc123xyz")
83
- False
84
- >>> validate_token_format("a" * 50)
85
- True
86
- """
87
- if not token or not isinstance(token, str):
88
- return False
89
-
90
- # Remove whitespace
91
- token = token.strip()
92
-
93
- # Check minimum length
94
- if len(token) < min_length:
95
- return False
96
-
97
- # Check for obviously invalid tokens
98
- return token.lower() not in ["none", "null", "undefined", ""]
99
-
100
-
101
- def validate_bearer_token(token: str) -> bool:
102
- """Validate Bearer token format.
103
-
104
- Args:
105
- token: Bearer token to validate.
106
-
107
- Returns:
108
- True if token appears valid, False otherwise.
109
-
110
- Example:
111
- >>> validate_bearer_token("ya29.a0AfH6SMB...")
112
- True
113
- >>> validate_bearer_token("invalid")
114
- False
115
- """
116
- # Bearer tokens are typically base64-like strings
117
- if not validate_token_format(token, min_length=20):
118
- return False
119
-
120
- # Check for suspicious patterns
121
- return not re.search(r"[<>\"']", token)
122
-
123
-
124
- def calculate_token_ttl(expires_at: datetime | None) -> timedelta | None:
125
- """Calculate time-to-live for a token.
126
-
127
- Args:
128
- expires_at: Token expiration timestamp.
129
-
130
- Returns:
131
- Time remaining until expiration, or None if no expiry.
132
-
133
- Example:
134
- >>> from datetime import datetime, timedelta
135
- >>> expires = datetime.now() + timedelta(hours=1)
136
- >>> ttl = calculate_token_ttl(expires)
137
- >>> ttl.total_seconds() > 3500 # Approximately 1 hour
138
- True
139
- """
140
- if expires_at is None:
141
- return None
142
-
143
- now = datetime.now()
144
- if expires_at <= now:
145
- return timedelta(0)
146
-
147
- return expires_at - now
148
-
149
-
150
- def should_proactively_refresh(
151
- credentials: AuthCredentials,
152
- refresh_threshold_minutes: int = 5,
153
- ) -> bool:
154
- """Determine if token should be proactively refreshed.
155
-
156
- Checks if token will expire soon and if refresh token is available.
157
-
158
- Args:
159
- credentials: Credentials to check.
160
- refresh_threshold_minutes: Refresh if expires within this many minutes.
161
-
162
- Returns:
163
- True if should proactively refresh, False otherwise.
164
-
165
- Example:
166
- >>> creds = AuthCredentials(
167
- ... platform="twitter",
168
- ... access_token="token",
169
- ... refresh_token="refresh",
170
- ... expires_at=datetime.now() + timedelta(minutes=3)
171
- ... )
172
- >>> should_proactively_refresh(creds)
173
- True
174
- """
175
- # Need refresh token to refresh
176
- if not credentials.refresh_token:
177
- return False
178
-
179
- # Check if expiring soon
180
- return is_token_expired(credentials.expires_at, refresh_threshold_minutes)
181
-
182
-
183
- def is_credentials_complete(credentials: AuthCredentials) -> bool:
184
- """Check if credentials have all required fields.
185
-
186
- Args:
187
- credentials: Credentials to validate.
188
-
189
- Returns:
190
- True if credentials are complete, False otherwise.
191
-
192
- Example:
193
- >>> creds = AuthCredentials(
194
- ... platform="twitter",
195
- ... access_token="token"
196
- ... )
197
- >>> is_credentials_complete(creds)
198
- True
199
- """
200
- # Must have platform and access token
201
- if not credentials.platform or not credentials.access_token:
202
- return False
203
-
204
- # Access token must be valid format
205
- return validate_token_format(credentials.access_token)
206
-
207
-
208
- def get_token_health_status(credentials: AuthCredentials) -> dict[str, Any]:
209
- """Get comprehensive health status of credentials.
210
-
211
- Args:
212
- credentials: Credentials to analyze.
213
-
214
- Returns:
215
- Dictionary with health information.
216
-
217
- Example:
218
- >>> creds = AuthCredentials(
219
- ... platform="twitter",
220
- ... access_token="token",
221
- ... expires_at=datetime.now() + timedelta(hours=1)
222
- ... )
223
- >>> status = get_token_health_status(creds)
224
- >>> status["is_valid"]
225
- True
226
- >>> status["needs_refresh"]
227
- False
228
- """
229
- ttl = calculate_token_ttl(credentials.expires_at)
230
-
231
- return {
232
- "is_valid": credentials.is_valid(),
233
- "is_expired": credentials.is_expired(),
234
- "needs_refresh": credentials.needs_refresh(),
235
- "has_refresh_token": credentials.refresh_token is not None,
236
- "time_to_expiry_seconds": ttl.total_seconds() if ttl else None,
237
- "should_proactively_refresh": should_proactively_refresh(credentials),
238
- "status": credentials.status.value,
239
- "is_complete": is_credentials_complete(credentials),
240
- }
File without changes