tumblrbot-1.9.7-py3-none-any.whl → tumblrbot-1.10.1-py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registries.
- tumblrbot/__main__.py +47 -43
- tumblrbot/flow/download.py +55 -55
- tumblrbot/flow/examples.py +94 -97
- tumblrbot/flow/fine_tune.py +141 -137
- tumblrbot/flow/generate.py +97 -97
- tumblrbot/utils/common.py +62 -57
- tumblrbot/utils/models.py +225 -225
- tumblrbot/utils/tumblr.py +83 -68
- {tumblrbot-1.9.7.dist-info → tumblrbot-1.10.1.dist-info}/METADATA +27 -40
- tumblrbot-1.10.1.dist-info/RECORD +16 -0
- {tumblrbot-1.9.7.dist-info → tumblrbot-1.10.1.dist-info}/WHEEL +2 -1
- tumblrbot-1.10.1.dist-info/entry_points.txt +2 -0
- tumblrbot-1.10.1.dist-info/top_level.txt +1 -0
- tumblrbot-1.9.7.dist-info/RECORD +0 -15
- tumblrbot-1.9.7.dist-info/entry_points.txt +0 -3
tumblrbot/utils/models.py
CHANGED
@@ -1,225 +1,225 @@
 from getpass import getpass
 from pathlib import Path
 from tomllib import loads
 from typing import TYPE_CHECKING, Annotated, Any, Literal, Self, override
 
-from openai.types import
+from openai.types import ResponsesModel  # noqa: TC002
 from pydantic import BaseModel, ConfigDict, Field, NonNegativeFloat, NonNegativeInt, PlainSerializer, PositiveFloat, PositiveInt, model_validator
 from pydantic.json_schema import SkipJsonSchema  # noqa: TC002
 from requests_oauthlib import OAuth1Session
 from rich import print as rich_print
 from rich.panel import Panel
 from rich.prompt import Prompt
 from tomlkit import comment, document, dumps  # pyright: ignore[reportUnknownVariableType]
 
 if TYPE_CHECKING:
     from collections.abc import Generator
 
 
 class FullyValidatedModel(BaseModel):
     model_config = ConfigDict(
         extra="ignore",
         validate_assignment=True,
         validate_default=True,
         validate_return=True,
         validate_by_name=True,
     )
 
 
 class FileSyncSettings(FullyValidatedModel):
     @classmethod
     def get_toml_file(cls) -> Path:
         return Path(f"{cls.__name__.lower()}.toml")
 
     @classmethod
     def load(cls) -> Self:
         toml_file = cls.get_toml_file()
         data = loads(toml_file.read_text("utf_8")) if toml_file.exists() else {}
         return cls.model_validate(data)
 
     @model_validator(mode="after")
     def dump(self) -> Self:
         toml_table = document()
 
         for (name, field), value in zip(self.__class__.model_fields.items(), self.model_dump(mode="json").values(), strict=True):
             if field.description is not None:
                 for line in field.description.split(". "):
                     toml_table.add(comment(f"{line.removesuffix('.')}."))
 
             toml_table[name] = value
 
         self.get_toml_file().write_text(dumps(toml_table), encoding="utf_8")
 
         return self
 
 
 class Config(FileSyncSettings):
     # Downloading Posts & Writing Examples
     download_blog_identifiers: list[str] = Field([], description="The identifiers of the blogs which post data will be downloaded from.")
     data_directory: Path = Field(Path("data"), description="Where to store downloaded post data.")
 
     # Writing Examples
     post_limit: NonNegativeInt = Field(0, description="The number of the most recent posts from each blog that should be included in the training data.")
     moderation_batch_size: PositiveInt = Field(25, description="The number of posts at a time to submit to the OpenAI moderation API.")
     custom_prompts_file: Path = Field(Path("custom_prompts.jsonl"), description="Where to read in custom prompts from.")
     filtered_words: list[str] = Field([], description="A case-insensitive list of disallowed words used to filter out training data. Regular expressions are allowed, but must be escaped.")
 
     # Writing Examples & Fine-Tuning
     examples_file: Path = Field(Path("examples.jsonl"), description="Where to output the examples that will be used to fine-tune the model.")
 
     # Writing Examples & Generating
     developer_message: str = Field("You are a Tumblr post bot. Please generate a Tumblr post in accordance with the user's request.", description="The developer message used by the OpenAI API to generate drafts.")
     user_message: str = Field("Please write a comical Tumblr post.", description="The user input used by the OpenAI API to generate drafts.")
 
     # Fine-Tuning
     expected_epochs: PositiveInt = Field(3, description="The expected number of epochs fine-tuning will be run for. This will be updated during fine-tuning.")
     token_price: PositiveFloat = Field(3, description="The expected price in USD per million tokens during fine-tuning for the current model.")
     job_id: str = Field("", description="The fine-tuning job ID that will be polled on next run.")
 
     # Fine-Tuning & Generating
-    base_model:
+    base_model: ResponsesModel = Field("gpt-4o-mini-2024-07-18", description="The name of the model that will be fine-tuned by the generated training data.")
     fine_tuned_model: str = Field("", description="The name of the OpenAI model that was fine-tuned with your posts.")
 
     # Generating
     upload_blog_identifier: str = Field("", description="The identifier of the blog which generated drafts will be uploaded to. This must be a blog associated with the same account as the configured Tumblr secret tokens.")
     draft_count: PositiveInt = Field(100, description="The number of drafts to process. This will affect the number of tokens used with OpenAI")
     tags_chance: NonNegativeFloat = Field(0.1, description="The chance to generate tags for any given post. This will use more OpenAI tokens.")
     tags_developer_message: str = Field("You will be provided with a block of text, and your task is to extract a very short list of the most important subjects from it.", description="The developer message used to generate tags.")
     reblog_blog_identifiers: list[str] = Field([], description="The identifiers of blogs that can be reblogged from when generating drafts.")
     reblog_chance: NonNegativeFloat = Field(0.1, description="The chance to generate a reblog of a random post. This will use more OpenAI tokens.")
     reblog_user_message: str = Field("Please write a comical Tumblr post in response to the following post:\n\n{}", description="The format string for the user message used to reblog posts.")
 
     @override
     def model_post_init(self, context: object) -> None:
         super().model_post_init(context)
 
         if not self.download_blog_identifiers:
             rich_print("Enter the [cyan]identifiers of your blogs[/] that data should be [bold purple]downloaded[/] from, separated by commas.")
             self.download_blog_identifiers = list(map(str.strip, Prompt.ask("[bold][Example] [dim]staff.tumblr.com,changes").split(",")))
 
         if not self.upload_blog_identifier:
             rich_print("Enter the [cyan]identifier of your blog[/] that drafts should be [bold purple]uploaded[/] to.")
             self.upload_blog_identifier = Prompt.ask("[bold][Example] [dim]staff.tumblr.com or changes").strip()
 
 
 class Tokens(FileSyncSettings):
     class Tumblr(FullyValidatedModel):
         client_key: str = ""
         client_secret: str = ""
         resource_owner_key: str = ""
         resource_owner_secret: str = ""
 
     openai_api_key: str = ""
     tumblr: Tumblr = Tumblr()
 
     @staticmethod
     def online_token_prompt(url: str, *tokens: str) -> Generator[str]:
         formatted_token_string = " and ".join(f"[cyan]{token}[/]" for token in tokens)
 
         rich_print(f"Retrieve your {formatted_token_string} from: {url}")
         for token in tokens:
             yield getpass(f"Enter your {token} (masked): ", echo_char="*").strip()
 
         rich_print()
 
     @override
     def model_post_init(self, context: object) -> None:
         super().model_post_init(context)
 
         # Check if any tokens are missing or if the user wants to reset them, then set tokens if necessary.
         if not self.openai_api_key:
             (self.openai_api_key,) = self.online_token_prompt("https://platform.openai.com/api-keys", "API key")
 
         if not all(self.tumblr.model_dump().values()):
             self.tumblr.client_key, self.tumblr.client_secret = self.online_token_prompt("https://tumblr.com/oauth/apps", "consumer key", "consumer secret")
 
             # This is the whole OAuth 1.0 process.
             # https://requests-oauthlib.readthedocs.io/en/latest/examples/tumblr.html
             # We tried setting up OAuth 2.0, but the token refresh process is far too unreliable for this sort of program.
             with OAuth1Session(**self.tumblr.model_dump()) as session:
                 session.fetch_request_token("http://tumblr.com/oauth/request_token")  # pyright: ignore[reportUnknownMemberType]
 
                 rich_print("Open the link below in your browser, and authorize this application.\nAfter authorizing, copy and paste the URL of the page you are redirected to below.")
                 authorization_url = session.authorization_url("http://tumblr.com/oauth/authorize")  # pyright: ignore[reportUnknownMemberType]
                 (authorization_response,) = self.online_token_prompt(authorization_url, "full redirect URL")
                 session.parse_authorization_response(authorization_response)
 
                 access_token = session.fetch_access_token("http://tumblr.com/oauth/access_token")  # pyright: ignore[reportUnknownMemberType]
 
                 self.tumblr.resource_owner_key = access_token["oauth_token"]
                 self.tumblr.resource_owner_secret = access_token["oauth_token_secret"]
 
 
 class Blog(FullyValidatedModel):
     name: str = ""
     posts: int = 0
     uuid: str = ""
 
 
 class ResponseModel(FullyValidatedModel):
     class Response(FullyValidatedModel):
         blog: Blog = Blog()
         posts: list[Any] = []
 
     response: Response
 
 
 class Block(FullyValidatedModel):
     type: str = "text"
     text: str = ""
     blocks: list[int] = []
 
 
 class Post(FullyValidatedModel):
     blog: Blog = Blog()
     id: int = 0
     parent_tumblelog_uuid: str = ""
     parent_post_id: int = 0
     reblog_key: str = ""
 
     timestamp: int = 0
     tags: Annotated[list[str], PlainSerializer(",".join)] = []
     state: Literal["published", "queued", "draft", "private", "unapproved"] = "draft"
 
     content: list[Block] = []
     layout: list[Block] = []
     trail: list[Self] = []
 
     is_submission: SkipJsonSchema[bool] = False
 
     def __rich__(self) -> Panel:
         return Panel(
             str(self),
             title="Preview",
             subtitle=" ".join(f"#{tag}" for tag in self.tags),
             subtitle_align="left",
         )
 
     def __str__(self) -> str:
         # This function is really only relevant when a post is already valid, so we don't have to check the block types.
         # If it is called on an invalid post, it would also work, but might give strange data.
         return "\n\n".join(block.text for block in self.content)
 
     def valid_text_post(self) -> bool:
         # Checks if this post:
         # - has any content blocks (some glitched empty posts have no content)
         # - only has content blocks of type 'text' (this excludes photo/video/poll/etc posts)
         # - is not a submitted post
         # - has no ask blocks in the content
         return bool(self.content) and all(block.type == "text" for block in self.content) and not (self.is_submission or any(block.type == "ask" for block in self.layout))
 
 
 class Message(FullyValidatedModel):
     role: Literal["developer", "user", "assistant"]
     content: str
 
 
 class Example(FullyValidatedModel):
     messages: list[Message]
 
     def get_assistant_message(self) -> str:
         for message in self.messages:
             if message.role == "assistant":
                 return message.content
         msg = "Assistant message not found!"
         raise ValueError(msg)
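The FileSyncSettings base class above round-trips each settings model through a TOML file named after the class (config.toml for Config, tokens.toml for Tokens); because dump is a mode="after" model validator, every successful validation writes the file back out with the field descriptions rendered as comments. A minimal usage sketch of that behavior, assuming the package is importable as tumblrbot and that you accept the interactive prompts for any missing blog identifiers:

from tumblrbot.utils.models import Config

# Reads config.toml if it exists (falling back to the declared defaults),
# validates it, and immediately re-serializes it back to config.toml.
config = Config.load()

# With validate_assignment=True, assignments are re-validated, so the
# after-validator should persist this change straight back to config.toml.
config.draft_count = 10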
tumblrbot/utils/tumblr.py
CHANGED
@@ -1,68 +1,83 @@
+from locale import str as locale_str
+
+from requests import HTTPError, Response, Session
+from requests_oauthlib import OAuth1
+from rich import print as rich_print
+from rich.pretty import pprint
+from tenacity import RetryCallState, retry, retry_if_exception_message
+
+from tumblrbot.utils.models import Post, ResponseModel, Tokens
+
+
+def wait_until_ratelimit_reset(retry_state: RetryCallState) -> float:
+    if retry_state.outcome is not None:
+        exception = retry_state.outcome.exception()
+        if isinstance(exception, HTTPError):
+            ratelimit_type = "day" if exception.response.headers["X-Ratelimit-Perday-Remaining"] == "0" else "hour"
+            return float(exception.response.headers[f"X-Ratelimit-Per{ratelimit_type}-Reset"])
+    return 0
+
+
+rate_limit_retry = retry(
+    wait=wait_until_ratelimit_reset,
+    retry=retry_if_exception_message(match="429 Client Error: Limit Exceeded for url: .+"),
+    before_sleep=lambda state: rich_print(f"[yellow]Tumblr rate limit exceeded. Waiting for {locale_str(state.upcoming_sleep)} seconds..."),
+)
+
+
+class TumblrSession(Session):
+    def __init__(self, tokens: Tokens) -> None:
+        super().__init__()
+        self.auth = OAuth1(**tokens.tumblr.model_dump())
+        self.hooks["response"].append(self.response_hook)
+
+        self.api_key = tokens.tumblr.client_key
+
+    def response_hook(self, response: Response, *_args: object, **_kwargs: object) -> None:
+        try:
+            response.raise_for_status()
+        except HTTPError as error:
+            for error_msg in response.json()["errors"]:
+                error.add_note(f"{error_msg['code']}: {error_msg['detail']}")
+            raise
+
+    @rate_limit_retry
+    def retrieve_blog_info(self, blog_identifier: str) -> ResponseModel:
+        response = self.get(
+            f"https://api.tumblr.com/v2/blog/{blog_identifier}/info",
+            params={
+                "api_key": self.api_key,
+            },
+        )
+        return ResponseModel.model_validate_json(response.text)
+
+    @rate_limit_retry
+    def retrieve_published_posts(
+        self,
+        blog_identifier: str,
+        offset: int | None = None,
+        after: int | None = None,
+    ) -> ResponseModel:
+        response = self.get(
+            f"https://api.tumblr.com/v2/blog/{blog_identifier}/posts",
+            params={
+                "api_key": self.api_key,
+                "offset": offset,
+                "after": after,
+                "sort": "asc",
+                "npf": True,
+            },
+        )
+        try:
+            return ResponseModel.model_validate_json(response.text)
+        except:
+            pprint(response.headers)
+            raise
+
+    @rate_limit_retry
+    def create_post(self, blog_identifier: str, post: Post) -> ResponseModel:
+        response = self.post(
+            f"https://api.tumblr.com/v2/blog/{blog_identifier}/posts",
+            json=post.model_dump(),
+        )
+        return ResponseModel.model_validate_json(response.text)
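The rewritten tumblr.py builds the client directly on requests: TumblrSession subclasses Session, attaches OAuth1 credentials from Tokens, registers a response hook that attaches Tumblr's error codes and details to the raised HTTPError, and wraps each endpoint with rate_limit_retry, which sleeps until the relevant X-Ratelimit-Per…-Reset header expires whenever Tumblr answers with "429 Client Error: Limit Exceeded". A minimal usage sketch, assuming valid credentials are already stored in tokens.toml:

from tumblrbot.utils.models import Tokens
from tumblrbot.utils.tumblr import TumblrSession

# requests.Session supports the context-manager protocol, so the connection
# pool is closed when the block exits.
with TumblrSession(Tokens.load()) as session:
    # Retried automatically on a 429 via rate_limit_retry.
    info = session.retrieve_blog_info("staff.tumblr.com")
    print(info.response.blog.name, info.response.blog.posts)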