tumblrbot 1.9.6__py3-none-any.whl → 1.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tumblrbot/utils/models.py CHANGED
@@ -1,236 +1,225 @@
1
- from getpass import getpass
2
- from pathlib import Path
3
- from tomllib import loads
4
- from typing import TYPE_CHECKING, Annotated, Any, Literal, Self, override
5
-
6
- from openai.types import ChatModel # noqa: TC002
7
- from pydantic import BaseModel, ConfigDict, Field, NonNegativeFloat, NonNegativeInt, PlainSerializer, PositiveFloat, PositiveInt, model_validator
8
- from pydantic.json_schema import SkipJsonSchema # noqa: TC002
9
- from requests_oauthlib import OAuth1Session
10
- from rich import print as rich_print
11
- from rich.panel import Panel
12
- from rich.prompt import Prompt
13
- from tomlkit import comment, document, dumps # pyright: ignore[reportUnknownVariableType]
14
-
15
- if TYPE_CHECKING:
16
- from collections.abc import Generator
17
-
18
-
19
- class FullyValidatedModel(BaseModel):
20
- model_config = ConfigDict(
21
- extra="ignore",
22
- validate_assignment=True,
23
- validate_default=True,
24
- validate_return=True,
25
- validate_by_name=True,
26
- )
27
-
28
-
29
- class FileSyncSettings(FullyValidatedModel):
30
- @classmethod
31
- def get_toml_file(cls) -> Path:
32
- return Path(f"{cls.__name__.lower()}.toml")
33
-
34
- @classmethod
35
- def load(cls) -> Self:
36
- toml_file = cls.get_toml_file()
37
- data = loads(toml_file.read_text("utf_8")) if toml_file.exists() else {}
38
- return cls.model_validate(data)
39
-
40
- @model_validator(mode="after")
41
- def dump(self) -> Self:
42
- toml_table = document()
43
-
44
- for (name, field), value in zip(self.__class__.model_fields.items(), self.model_dump(mode="json").values(), strict=True):
45
- if field.description is not None:
46
- for line in field.description.split(". "):
47
- toml_table.add(comment(f"{line.removesuffix('.')}."))
48
-
49
- toml_table[name] = value
50
-
51
- self.get_toml_file().write_text(dumps(toml_table), encoding="utf_8")
52
-
53
- return self
54
-
55
-
56
- class Config(FileSyncSettings):
57
- # Downloading Posts & Writing Examples
58
- download_blog_identifiers: list[str] = Field([], description="The identifiers of the blogs which post data will be downloaded from.")
59
- data_directory: Path = Field(Path("data"), description="Where to store downloaded post data.")
60
-
61
- # Writing Examples
62
- post_limit: NonNegativeInt = Field(0, description="The number of the most recent posts from each blog that should be included in the training data.")
63
- moderation_batch_size: PositiveInt = Field(25, description="The number of posts at a time to submit to the OpenAI moderation API.")
64
- custom_prompts_file: Path = Field(Path("custom_prompts.jsonl"), description="Where to read in custom prompts from.")
65
- filtered_words: list[str] = Field([], description="A case-insensitive list of disallowed words used to filter out training data. Regular expressions are allowed, but must be escaped.")
66
-
67
- # Writing Examples & Fine-Tuning
68
- examples_file: Path = Field(Path("examples.jsonl"), description="Where to output the examples that will be used to fine-tune the model.")
69
-
70
- # Writing Examples & Generating
71
- developer_message: str = Field("You are a Tumblr post bot. Please generate a Tumblr post in accordance with the user's request.", description="The developer message used by the OpenAI API to generate drafts.")
72
- user_message: str = Field("Please write a comical Tumblr post.", description="The user input used by the OpenAI API to generate drafts.")
73
-
74
- # Fine-Tuning
75
- expected_epochs: PositiveInt = Field(3, description="The expected number of epochs fine-tuning will be run for. This will be updated during fine-tuning.")
76
- token_price: PositiveFloat = Field(3, description="The expected price in USD per million tokens during fine-tuning for the current model.")
77
- job_id: str = Field("", description="The fine-tuning job ID that will be polled on next run.")
78
-
79
- # Fine-Tuning & Generating
80
- base_model: ChatModel = Field("gpt-4o-mini-2024-07-18", description="The name of the model that will be fine-tuned by the generated training data.")
81
- fine_tuned_model: str = Field("", description="The name of the OpenAI model that was fine-tuned with your posts.")
82
-
83
- # Generating
84
- upload_blog_identifier: str = Field("", description="The identifier of the blog which generated drafts will be uploaded to. This must be a blog associated with the same account as the configured Tumblr secret tokens.")
85
- draft_count: PositiveInt = Field(100, description="The number of drafts to process. This will affect the number of tokens used with OpenAI")
86
- tags_chance: NonNegativeFloat = Field(0.1, description="The chance to generate tags for any given post. This will use more OpenAI tokens.")
87
- tags_developer_message: str = Field("You will be provided with a block of text, and your task is to extract a very short list of the most important subjects from it.", description="The developer message used to generate tags.")
88
- reblog_blog_identifiers: list[str] = Field([], description="The identifiers of blogs that can be reblogged from when generating drafts.")
89
- reblog_chance: NonNegativeFloat = Field(0.1, description="The chance to generate a reblog of a random post. This will use more OpenAI tokens.")
90
- reblog_user_message: str = Field("Please write a comical Tumblr post in response to the following post:\n\n{}", description="The format string for the user message used to reblog posts.")
91
-
92
- @override
93
- def model_post_init(self, context: object) -> None:
94
- super().model_post_init(context)
95
-
96
- if not self.download_blog_identifiers:
97
- rich_print("Enter the [cyan]identifiers of your blogs[/] that data should be [bold purple]downloaded[/] from, separated by commas.")
98
- self.download_blog_identifiers = list(map(str.strip, Prompt.ask("[bold][Example] [dim]staff.tumblr.com,changes").split(",")))
99
-
100
- if not self.upload_blog_identifier:
101
- rich_print("Enter the [cyan]identifier of your blog[/] that drafts should be [bold purple]uploaded[/] to.")
102
- self.upload_blog_identifier = Prompt.ask("[bold][Example] [dim]staff.tumblr.com or changes").strip()
103
-
104
-
105
- class Tokens(FileSyncSettings):
106
- class Tumblr(FullyValidatedModel):
107
- client_key: str = ""
108
- client_secret: str = ""
109
- resource_owner_key: str = ""
110
- resource_owner_secret: str = ""
111
-
112
- openai_api_key: str = ""
113
- tumblr: Tumblr = Tumblr()
114
-
115
- @override
116
- def model_post_init(self, context: object) -> None:
117
- super().model_post_init(context)
118
-
119
- # Check if any tokens are missing or if the user wants to reset them, then set tokens if necessary.
120
- if not self.openai_api_key:
121
- (self.openai_api_key,) = self.online_token_prompt("https://platform.openai.com/api-keys", "API key")
122
-
123
- if not all(self.tumblr.model_dump().values()):
124
- self.tumblr.client_key, self.tumblr.client_secret = self.online_token_prompt("https://tumblr.com/oauth/apps", "consumer key", "consumer secret")
125
-
126
- # This is the whole OAuth 1.0 process.
127
- # https://requests-oauthlib.readthedocs.io/en/latest/examples/tumblr.html
128
- # We tried setting up OAuth 2.0, but the token refresh process is far too unreliable for this sort of program.
129
- with OAuth1Session(
130
- self.tumblr.client_key,
131
- self.tumblr.client_secret,
132
- ) as oauth_session:
133
- fetch_response = oauth_session.fetch_request_token("http://tumblr.com/oauth/request_token") # pyright: ignore[reportUnknownMemberType]
134
- full_authorize_url = oauth_session.authorization_url("http://tumblr.com/oauth/authorize") # pyright: ignore[reportUnknownMemberType]
135
- (redirect_response,) = self.online_token_prompt(full_authorize_url, "full redirect URL")
136
- oauth_response = oauth_session.parse_authorization_response(redirect_response)
137
-
138
- with OAuth1Session(
139
- self.tumblr.client_key,
140
- self.tumblr.client_secret,
141
- *self.get_oauth_tokens(fetch_response),
142
- verifier=oauth_response["oauth_verifier"],
143
- ) as oauth_session:
144
- oauth_tokens = oauth_session.fetch_access_token("http://tumblr.com/oauth/access_token") # pyright: ignore[reportUnknownMemberType]
145
-
146
- self.tumblr.resource_owner_key, self.tumblr.resource_owner_secret = self.get_oauth_tokens(oauth_tokens)
147
-
148
- @staticmethod
149
- def online_token_prompt(url: str, *tokens: str) -> Generator[str]:
150
- formatted_token_string = " and ".join(f"[cyan]{token}[/]" for token in tokens)
151
-
152
- rich_print(f"Retrieve your {formatted_token_string} from: {url}")
153
- for token in tokens:
154
- yield getpass(f"Enter your {token} (masked): ", echo_char="*").strip()
155
-
156
- rich_print()
157
-
158
- @staticmethod
159
- def get_oauth_tokens(token: dict[str, str]) -> tuple[str, str]:
160
- return token["oauth_token"], token["oauth_token_secret"]
161
-
162
-
163
- class Blog(FullyValidatedModel):
164
- name: str = ""
165
- posts: int = 0
166
- uuid: str = ""
167
-
168
-
169
- class Response(FullyValidatedModel):
170
- blog: Blog = Blog()
171
- posts: list[Any] = []
172
-
173
-
174
- class ResponseModel(FullyValidatedModel):
175
- response: Response
176
-
177
-
178
- class Block(FullyValidatedModel):
179
- type: str = ""
180
- text: str = ""
181
- blocks: list[int] = []
182
-
183
-
184
- class Post(FullyValidatedModel):
185
- blog: SkipJsonSchema[Blog] = Blog()
186
- id: SkipJsonSchema[int] = 0
187
- parent_tumblelog_uuid: SkipJsonSchema[str] = ""
188
- parent_post_id: SkipJsonSchema[int] = 0
189
- reblog_key: SkipJsonSchema[str] = ""
190
-
191
- timestamp: SkipJsonSchema[int] = 0
192
- tags: Annotated[list[str], PlainSerializer(",".join)] = []
193
- state: SkipJsonSchema[Literal["published", "queued", "draft", "private", "unapproved"]] = "draft"
194
-
195
- content: SkipJsonSchema[list[Block]] = []
196
- layout: SkipJsonSchema[list[Block]] = []
197
- trail: SkipJsonSchema[list[Self]] = []
198
-
199
- is_submission: SkipJsonSchema[bool] = False
200
-
201
- def __rich__(self) -> Panel:
202
- return Panel(
203
- str(self),
204
- title="Preview",
205
- subtitle=" ".join(f"#{tag}" for tag in self.tags),
206
- subtitle_align="left",
207
- )
208
-
209
- def __str__(self) -> str:
210
- # This function is really only relevant when a post is already valid, so we don't have to check the block types.
211
- # If it is called on an invalid post, it would also work, but might give strange data.
212
- return "\n\n".join(block.text for block in self.content)
213
-
214
- def valid_text_post(self) -> bool:
215
- # Checks if this post:
216
- # - has any content blocks (some glitched empty posts have no content)
217
- # - only has content blocks of type 'text' (this excludes photo/video/poll/etc posts)
218
- # - is not a submitted post
219
- # - has no ask blocks in the content
220
- return bool(self.content) and all(block.type == "text" for block in self.content) and not (self.is_submission or any(block.type == "ask" for block in self.layout))
221
-
222
-
223
- class Message(FullyValidatedModel):
224
- role: Literal["developer", "user", "assistant"]
225
- content: str
226
-
227
-
228
- class Example(FullyValidatedModel):
229
- messages: list[Message]
230
-
231
- def get_assistant_message(self) -> str:
232
- for message in self.messages:
233
- if message.role == "assistant":
234
- return message.content
235
- msg = "Assistant message not found!"
236
- raise ValueError(msg)
1
+ from getpass import getpass
2
+ from pathlib import Path
3
+ from tomllib import loads
4
+ from typing import TYPE_CHECKING, Annotated, Any, Literal, Self, override
5
+
6
+ from openai.types import ResponsesModel # noqa: TC002
7
+ from pydantic import BaseModel, ConfigDict, Field, NonNegativeFloat, NonNegativeInt, PlainSerializer, PositiveFloat, PositiveInt, model_validator
8
+ from pydantic.json_schema import SkipJsonSchema # noqa: TC002
9
+ from requests_oauthlib import OAuth1Session
10
+ from rich import print as rich_print
11
+ from rich.panel import Panel
12
+ from rich.prompt import Prompt
13
+ from tomlkit import comment, document, dumps # pyright: ignore[reportUnknownVariableType]
14
+
15
+ if TYPE_CHECKING:
16
+ from collections.abc import Generator
17
+
18
+
19
class FullyValidatedModel(BaseModel):
    """Shared pydantic base enabling strict, always-on validation for all models in this package."""

    model_config = ConfigDict(
        # Silently drop unknown keys (tolerant of extra fields in Tumblr/OpenAI payloads).
        extra="ignore",
        # Re-validate whenever an attribute is assigned after construction.
        validate_assignment=True,
        # Validate declared defaults, not just caller-supplied values.
        validate_default=True,
        # Validate values returned from validators.
        validate_return=True,
        # Allow population by field name as well as by alias.
        validate_by_name=True,
    )
27
+
28
+
29
class FileSyncSettings(FullyValidatedModel):
    """Settings model that stays in sync with a TOML file named after the subclass.

    Because validation runs on every assignment (see FullyValidatedModel), the
    `dump` validator rewrites the file each time a field changes, keeping the
    on-disk copy current.
    """

    @classmethod
    def get_toml_file(cls) -> Path:
        """Return the TOML path derived from the lowercased subclass name."""
        return Path(f"{cls.__name__.lower()}.toml")

    @classmethod
    def load(cls) -> Self:
        """Build an instance from the TOML file, or from field defaults if it does not exist."""
        path = cls.get_toml_file()
        if path.exists():
            raw = loads(path.read_text("utf_8"))
        else:
            raw = {}
        return cls.model_validate(raw)

    @model_validator(mode="after")
    def dump(self) -> Self:
        """Write the current field values back to disk, with descriptions as comments."""
        doc = document()
        serialized = self.model_dump(mode="json")

        for (field_name, field_info), field_value in zip(self.__class__.model_fields.items(), serialized.values(), strict=True):
            description = field_info.description
            if description is not None:
                # Each sentence of the description becomes its own comment line,
                # normalized to end with exactly one period.
                for sentence in description.split(". "):
                    doc.add(comment(f"{sentence.removesuffix('.')}."))

            doc[field_name] = field_value

        self.get_toml_file().write_text(dumps(doc), encoding="utf_8")

        return self
54
+
55
+
56
class Config(FileSyncSettings):
    """User-facing configuration, persisted to ``config.toml``.

    Fields are grouped by the pipeline stage(s) that read them; each field's
    description is written to the TOML file as a comment by FileSyncSettings.dump.
    """

    # Downloading Posts & Writing Examples
    download_blog_identifiers: list[str] = Field([], description="The identifiers of the blogs which post data will be downloaded from.")
    data_directory: Path = Field(Path("data"), description="Where to store downloaded post data.")

    # Writing Examples
    post_limit: NonNegativeInt = Field(0, description="The number of the most recent posts from each blog that should be included in the training data.")
    moderation_batch_size: PositiveInt = Field(25, description="The number of posts at a time to submit to the OpenAI moderation API.")
    custom_prompts_file: Path = Field(Path("custom_prompts.jsonl"), description="Where to read in custom prompts from.")
    filtered_words: list[str] = Field([], description="A case-insensitive list of disallowed words used to filter out training data. Regular expressions are allowed, but must be escaped.")

    # Writing Examples & Fine-Tuning
    examples_file: Path = Field(Path("examples.jsonl"), description="Where to output the examples that will be used to fine-tune the model.")

    # Writing Examples & Generating
    developer_message: str = Field("You are a Tumblr post bot. Please generate a Tumblr post in accordance with the user's request.", description="The developer message used by the OpenAI API to generate drafts.")
    user_message: str = Field("Please write a comical Tumblr post.", description="The user input used by the OpenAI API to generate drafts.")

    # Fine-Tuning
    expected_epochs: PositiveInt = Field(3, description="The expected number of epochs fine-tuning will be run for. This will be updated during fine-tuning.")
    token_price: PositiveFloat = Field(3, description="The expected price in USD per million tokens during fine-tuning for the current model.")
    job_id: str = Field("", description="The fine-tuning job ID that will be polled on next run.")

    # Fine-Tuning & Generating
    base_model: ResponsesModel = Field("gpt-4o-mini-2024-07-18", description="The name of the model that will be fine-tuned by the generated training data.")
    fine_tuned_model: str = Field("", description="The name of the OpenAI model that was fine-tuned with your posts.")

    # Generating
    upload_blog_identifier: str = Field("", description="The identifier of the blog which generated drafts will be uploaded to. This must be a blog associated with the same account as the configured Tumblr secret tokens.")
    draft_count: PositiveInt = Field(100, description="The number of drafts to process. This will affect the number of tokens used with OpenAI")
    tags_chance: NonNegativeFloat = Field(0.1, description="The chance to generate tags for any given post. This will use more OpenAI tokens.")
    tags_developer_message: str = Field("You will be provided with a block of text, and your task is to extract a very short list of the most important subjects from it.", description="The developer message used to generate tags.")
    reblog_blog_identifiers: list[str] = Field([], description="The identifiers of blogs that can be reblogged from when generating drafts.")
    reblog_chance: NonNegativeFloat = Field(0.1, description="The chance to generate a reblog of a random post. This will use more OpenAI tokens.")
    reblog_user_message: str = Field("Please write a comical Tumblr post in response to the following post:\n\n{}", description="The format string for the user message used to reblog posts.")

    @override
    def model_post_init(self, context: object) -> None:
        """Interactively prompt for any required blog identifiers that are still unset.

        Note: each assignment below re-triggers validation and therefore rewrites
        the TOML file (see FileSyncSettings.dump).
        """
        super().model_post_init(context)

        if not self.download_blog_identifiers:
            rich_print("Enter the [cyan]identifiers of your blogs[/] that data should be [bold purple]downloaded[/] from, separated by commas.")
            self.download_blog_identifiers = list(map(str.strip, Prompt.ask("[bold][Example] [dim]staff.tumblr.com,changes").split(",")))

        if not self.upload_blog_identifier:
            rich_print("Enter the [cyan]identifier of your blog[/] that drafts should be [bold purple]uploaded[/] to.")
            self.upload_blog_identifier = Prompt.ask("[bold][Example] [dim]staff.tumblr.com or changes").strip()
103
+
104
+
105
class Tokens(FileSyncSettings):
    """API credentials for OpenAI and Tumblr, persisted to ``tokens.toml``.

    Missing credentials are collected interactively during model_post_init,
    including a full OAuth 1.0a authorization round-trip for Tumblr.
    """

    class Tumblr(FullyValidatedModel):
        # OAuth 1.0a credential set; field names intentionally match the
        # keyword arguments of requests_oauthlib's OAuth1/OAuth1Session,
        # so the model can be splatted directly into their constructors.
        client_key: str = ""
        client_secret: str = ""
        resource_owner_key: str = ""
        resource_owner_secret: str = ""

    openai_api_key: str = ""
    tumblr: Tumblr = Tumblr()

    @staticmethod
    def online_token_prompt(url: str, *tokens: str) -> Generator[str]:
        """Point the user at *url* and yield each requested token, read with masked input.

        Lazy generator: nothing is printed until the first value is consumed.
        """
        formatted_token_string = " and ".join(f"[cyan]{token}[/]" for token in tokens)

        rich_print(f"Retrieve your {formatted_token_string} from: {url}")
        for token in tokens:
            # NOTE: getpass's echo_char parameter requires a recent Python (3.14+).
            yield getpass(f"Enter your {token} (masked): ", echo_char="*").strip()

        rich_print()

    @override
    def model_post_init(self, context: object) -> None:
        super().model_post_init(context)

        # Check if any tokens are missing or if the user wants to reset them, then set tokens if necessary.
        if not self.openai_api_key:
            (self.openai_api_key,) = self.online_token_prompt("https://platform.openai.com/api-keys", "API key")

        if not all(self.tumblr.model_dump().values()):
            self.tumblr.client_key, self.tumblr.client_secret = self.online_token_prompt("https://tumblr.com/oauth/apps", "consumer key", "consumer secret")

            # This is the whole OAuth 1.0 process.
            # https://requests-oauthlib.readthedocs.io/en/latest/examples/tumblr.html
            # We tried setting up OAuth 2.0, but the token refresh process is far too unreliable for this sort of program.
            # The ordering below is load-bearing: the request token is fetched and
            # (presumably) cached on the session before the authorize URL is built,
            # and the redirect is parsed before the access token is exchanged.
            with OAuth1Session(**self.tumblr.model_dump()) as session:
                session.fetch_request_token("http://tumblr.com/oauth/request_token")  # pyright: ignore[reportUnknownMemberType]

                rich_print("Open the link below in your browser, and authorize this application.\nAfter authorizing, copy and paste the URL of the page you are redirected to below.")
                authorization_url = session.authorization_url("http://tumblr.com/oauth/authorize")  # pyright: ignore[reportUnknownMemberType]
                (authorization_response,) = self.online_token_prompt(authorization_url, "full redirect URL")
                session.parse_authorization_response(authorization_response)

                access_token = session.fetch_access_token("http://tumblr.com/oauth/access_token")  # pyright: ignore[reportUnknownMemberType]

                # Each assignment re-validates and rewrites tokens.toml.
                self.tumblr.resource_owner_key = access_token["oauth_token"]
                self.tumblr.resource_owner_secret = access_token["oauth_token_secret"]
151
+
152
+
153
class Blog(FullyValidatedModel):
    """Subset of the Tumblr API's blog object that this program reads."""

    name: str = ""
    # Total number of posts the blog reports having.
    posts: int = 0
    uuid: str = ""
157
+
158
+
159
class ResponseModel(FullyValidatedModel):
    """Envelope for Tumblr v2 API responses: only the ``response`` payload is modeled."""

    class Response(FullyValidatedModel):
        blog: Blog = Blog()
        # Raw post payloads; parsed into Post elsewhere.
        posts: list[Any] = []

    response: Response
165
+
166
+
167
class Block(FullyValidatedModel):
    """A single NPF content/layout block of a Tumblr post."""

    # Block kind, e.g. "text" or "ask"; defaults to "text".
    type: str = "text"
    text: str = ""
    # Indices of content blocks this (layout) block covers.
    blocks: list[int] = []
171
+
172
+
173
class Post(FullyValidatedModel):
    """A Tumblr post in NPF form, both as downloaded and as uploaded.

    Tags serialize to a comma-joined string, matching the format the Tumblr
    post-creation endpoint expects.
    """

    blog: Blog = Blog()
    id: int = 0
    parent_tumblelog_uuid: str = ""
    parent_post_id: int = 0
    reblog_key: str = ""

    timestamp: int = 0
    tags: Annotated[list[str], PlainSerializer(",".join)] = []
    state: Literal["published", "queued", "draft", "private", "unapproved"] = "draft"

    content: list[Block] = []
    layout: list[Block] = []
    trail: list[Self] = []

    is_submission: SkipJsonSchema[bool] = False

    def __rich__(self) -> Panel:
        """Render the post as a rich panel, with its tags as a hashtag subtitle."""
        hashtags = " ".join(f"#{tag}" for tag in self.tags)
        return Panel(
            str(self),
            title="Preview",
            subtitle=hashtags,
            subtitle_align="left",
        )

    def __str__(self) -> str:
        # This function is really only relevant when a post is already valid, so we don't have to check the block types.
        # If it is called on an invalid post, it would also work, but might give strange data.
        paragraphs = (block.text for block in self.content)
        return "\n\n".join(paragraphs)

    def valid_text_post(self) -> bool:
        """Return True when this post is a plain, non-submitted, non-ask text post.

        Checks that the post has at least one content block (some glitched empty
        posts have none), that every content block is 'text' (excluding
        photo/video/poll/etc posts), that it is not a submission, and that the
        layout contains no ask blocks.
        """
        if not self.content:
            return False
        if self.is_submission:
            return False
        if any(block.type != "text" for block in self.content):
            return False
        return all(block.type != "ask" for block in self.layout)
210
+
211
+
212
class Message(FullyValidatedModel):
    """One chat message of a fine-tuning example (OpenAI chat format)."""

    role: Literal["developer", "user", "assistant"]
    content: str
215
+
216
+
217
class Example(FullyValidatedModel):
    """A single fine-tuning example: an ordered list of chat messages."""

    messages: list[Message]

    def get_assistant_message(self) -> str:
        """Return the content of the first assistant message.

        Raises:
            ValueError: If the example contains no assistant message.
        """
        for candidate in self.messages:
            if candidate.role == "assistant":
                return candidate.content
        error_text = "Assistant message not found!"
        raise ValueError(error_text)
tumblrbot/utils/tumblr.py CHANGED
@@ -1,64 +1,83 @@
1
- from typing import Self
2
-
3
- from requests import HTTPError, Response
4
- from requests_oauthlib import OAuth1Session
5
- from rich import print as rich_print
6
- from tenacity import retry, retry_if_exception_message, stop_after_attempt, wait_random_exponential
7
-
8
- from tumblrbot.utils.models import Post, ResponseModel, Tokens
9
-
10
- rate_limit_retry = retry(
11
- stop=stop_after_attempt(10),
12
- wait=wait_random_exponential(min=60),
13
- retry=retry_if_exception_message(match="429 Client Error: Limit Exceeded for url: .+"),
14
- before_sleep=lambda state: rich_print(f"[yellow]Tumblr rate limit exceeded. Waiting for {state.idle_for} seconds..."),
15
- reraise=True,
16
- )
17
-
18
-
19
- class TumblrSession(OAuth1Session):
20
- def __init__(self, tokens: Tokens) -> None:
21
- super().__init__(**tokens.tumblr.model_dump()) # pyright: ignore[reportUnknownMemberType]
22
- self.hooks["response"].append(self.response_hook)
23
-
24
- def __enter__(self) -> Self:
25
- super().__enter__()
26
- return self
27
-
28
- def response_hook(self, response: Response, *_args: object, **_kwargs: object) -> None:
29
- try:
30
- response.raise_for_status()
31
- except HTTPError as error:
32
- error.add_note(response.text)
33
- raise
34
-
35
- @rate_limit_retry
36
- def retrieve_blog_info(self, blog_identifier: str) -> ResponseModel:
37
- response = self.get(f"https://api.tumblr.com/v2/blog/{blog_identifier}/info")
38
- return ResponseModel.model_validate_json(response.text)
39
-
40
- @rate_limit_retry
41
- def retrieve_published_posts(
42
- self,
43
- blog_identifier: str,
44
- offset: int | None = None,
45
- after: int | None = None,
46
- ) -> ResponseModel:
47
- response = self.get(
48
- f"https://api.tumblr.com/v2/blog/{blog_identifier}/posts",
49
- params={
50
- "offset": offset,
51
- "after": after,
52
- "sort": "asc",
53
- "npf": True,
54
- },
55
- )
56
- return ResponseModel.model_validate_json(response.text)
57
-
58
- @rate_limit_retry
59
- def create_post(self, blog_identifier: str, post: Post) -> ResponseModel:
60
- response = self.post(
61
- f"https://api.tumblr.com/v2/blog/{blog_identifier}/posts",
62
- json=post.model_dump(),
63
- )
64
- return ResponseModel.model_validate_json(response.text)
1
+ from locale import str as locale_str
2
+
3
+ from requests import HTTPError, Response, Session
4
+ from requests_oauthlib import OAuth1
5
+ from rich import print as rich_print
6
+ from rich.pretty import pprint
7
+ from tenacity import RetryCallState, retry, retry_if_exception_message
8
+
9
+ from tumblrbot.utils.models import Post, ResponseModel, Tokens
10
+
11
+
12
def wait_until_ratelimit_reset(retry_state: RetryCallState) -> float:
    """Tenacity wait callback: seconds until Tumblr's rate limit resets.

    Inspects the X-Ratelimit-* headers on the HTTPError that triggered the
    retry. If the daily quota is exhausted, waits for the daily reset;
    otherwise waits for the (shorter) hourly reset. Falls back to 0 whenever
    the response or its headers are unavailable or malformed, so the retry
    machinery itself can never crash here.
    """
    if retry_state.outcome is not None:
        exception = retry_state.outcome.exception()
        # exception.response can be None for connection-level failures.
        if isinstance(exception, HTTPError) and exception.response is not None:
            headers = exception.response.headers
            # Use .get(): Tumblr may omit these headers, and a KeyError here
            # would abort the retry loop instead of waiting.
            ratelimit_type = "day" if headers.get("X-Ratelimit-Perday-Remaining") == "0" else "hour"
            try:
                return float(headers.get(f"X-Ratelimit-Per{ratelimit_type}-Reset", 0))
            except ValueError:
                # Malformed header value; fall through to no wait.
                return 0
    return 0
19
+
20
+
21
# Decorator applied to every Tumblr API call: retries requests that failed
# with Tumblr's HTTP 429 ("Limit Exceeded") error, sleeping until the rate
# limit window resets (see wait_until_ratelimit_reset) and announcing the
# locale-formatted wait time before each sleep.
rate_limit_retry = retry(
    wait=wait_until_ratelimit_reset,
    retry=retry_if_exception_message(match="429 Client Error: Limit Exceeded for url: .+"),
    before_sleep=lambda state: rich_print(f"[yellow]Tumblr rate limit exceeded. Waiting for {locale_str(state.upcoming_sleep)} seconds..."),
)
26
+
27
+
28
class TumblrSession(Session):
    """A requests Session pre-configured for the Tumblr v2 API.

    Every request is signed with OAuth 1.0a, and every error response raises an
    HTTPError annotated with Tumblr's error details. All API methods retry on
    rate-limit errors via the module-level rate_limit_retry decorator.
    """

    def __init__(self, tokens: Tokens) -> None:
        super().__init__()
        self.auth = OAuth1(**tokens.tumblr.model_dump())
        self.hooks["response"].append(self.response_hook)

        # The consumer key doubles as the api_key query parameter on read endpoints.
        self.api_key = tokens.tumblr.client_key

    def response_hook(self, response: Response, *_args: object, **_kwargs: object) -> None:
        """Raise for error responses, attaching Tumblr's structured error messages as notes."""
        try:
            response.raise_for_status()
        except HTTPError as error:
            # Parsing the body can itself fail (non-JSON body, missing key);
            # guard it so a parse failure cannot mask the original HTTPError.
            try:
                error_messages = response.json()["errors"]
            except (ValueError, KeyError):
                error.add_note(response.text)
            else:
                for error_msg in error_messages:
                    error.add_note(f"{error_msg['code']}: {error_msg['detail']}")
            raise

    @rate_limit_retry
    def retrieve_blog_info(self, blog_identifier: str) -> ResponseModel:
        """Fetch general information about the given blog."""
        response = self.get(
            f"https://api.tumblr.com/v2/blog/{blog_identifier}/info",
            params={
                "api_key": self.api_key,
            },
        )
        return ResponseModel.model_validate_json(response.text)

    @rate_limit_retry
    def retrieve_published_posts(
        self,
        blog_identifier: str,
        offset: int | None = None,
        after: int | None = None,
    ) -> ResponseModel:
        """Fetch a page of the blog's published posts in NPF format, ascending order.

        Args:
            blog_identifier: The blog to read posts from.
            offset: Post offset to start from, if any.
            after: Only return posts newer than this timestamp, if given.
        """
        response = self.get(
            f"https://api.tumblr.com/v2/blog/{blog_identifier}/posts",
            params={
                "api_key": self.api_key,
                "offset": offset,
                "after": after,
                "sort": "asc",
                "npf": True,
            },
        )
        try:
            return ResponseModel.model_validate_json(response.text)
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit pass
            # through untouched; dump the headers to aid debugging, then re-raise.
            pprint(response.headers)
            raise

    @rate_limit_retry
    def create_post(self, blog_identifier: str, post: Post) -> ResponseModel:
        """Create a post (e.g. a generated draft) on the given blog."""
        response = self.post(
            f"https://api.tumblr.com/v2/blog/{blog_identifier}/posts",
            json=post.model_dump(),
        )
        return ResponseModel.model_validate_json(response.text)