tumblrbot 1.9.6-py3-none-any.whl → 1.10.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tumblrbot/__main__.py CHANGED
@@ -1,43 +1,47 @@
- from sys import exit as sys_exit
-
- from openai import OpenAI
- from rich.prompt import Confirm
- from rich.traceback import install
-
- from tumblrbot.flow.download import PostDownloader
- from tumblrbot.flow.examples import ExamplesWriter
- from tumblrbot.flow.fine_tune import FineTuner
- from tumblrbot.flow.generate import DraftGenerator
- from tumblrbot.utils.common import FlowClass
- from tumblrbot.utils.models import Tokens
- from tumblrbot.utils.tumblr import TumblrSession
-
-
- def main() -> None:
-     install()
-
-     tokens = Tokens.load()
-     with OpenAI(api_key=tokens.openai_api_key) as openai, TumblrSession(tokens) as tumblr:
-         if Confirm.ask("Download latest posts?", default=False):
-             PostDownloader(openai=openai, tumblr=tumblr).main()
-
-         examples_writer = ExamplesWriter(openai=openai, tumblr=tumblr)
-         if Confirm.ask("Create training data?", default=False):
-             examples_writer.main()
-
-         if Confirm.ask("Remove training data flagged by the OpenAI moderation? [bold]This can sometimes resolve errors with fine-tuning validation, but is slow.", default=False):
-             examples_writer.filter_examples()
-
-         fine_tuner = FineTuner(openai=openai, tumblr=tumblr)
-         fine_tuner.print_estimates()
-
-         message = "Resume monitoring the previous fine-tuning process?" if FlowClass.config.job_id else "Upload data to OpenAI for fine-tuning?"
-         if Confirm.ask(f"{message} [bold]You must do this to set the model to generate drafts from. Alternatively, manually enter a model into the config", default=False):
-             fine_tuner.main()
-
-         if Confirm.ask("Generate drafts?", default=False):
-             DraftGenerator(openai=openai, tumblr=tumblr).main()
-
-
- if __name__ == "__main__":
-     sys_exit(main())
+ from locale import LC_ALL, setlocale
+ from sys import exit as sys_exit
+ from sys import maxsize
+
+ from openai import OpenAI
+ from rich.prompt import Confirm
+ from rich.traceback import install
+
+ from tumblrbot.flow.download import PostDownloader
+ from tumblrbot.flow.examples import ExamplesWriter
+ from tumblrbot.flow.fine_tune import FineTuner
+ from tumblrbot.flow.generate import DraftGenerator
+ from tumblrbot.utils.common import FlowClass
+ from tumblrbot.utils.models import Tokens
+ from tumblrbot.utils.tumblr import TumblrSession
+
+
+ def main() -> None:
+     setlocale(LC_ALL, "")
+
+     install()
+
+     tokens = Tokens.load()
+     with OpenAI(api_key=tokens.openai_api_key, max_retries=maxsize) as openai, TumblrSession(tokens) as tumblr:
+         if Confirm.ask("Download latest posts?", default=False):
+             PostDownloader(openai, tumblr).main()
+
+         examples_writer = ExamplesWriter(openai, tumblr)
+         if Confirm.ask("Create training data?", default=False):
+             examples_writer.main()
+
+         if Confirm.ask("Remove training data flagged by the OpenAI moderation? [bold]This can sometimes resolve errors with fine-tuning validation, but is slow.", default=False):
+             examples_writer.filter_examples()
+
+         fine_tuner = FineTuner(openai, tumblr)
+         fine_tuner.print_estimates()
+
+         message = "Resume monitoring the previous fine-tuning process?" if FlowClass.config.job_id else "Upload data to OpenAI for fine-tuning?"
+         if Confirm.ask(f"{message} [bold]You must do this to set the model to generate drafts from. Alternatively, manually enter a model into the config", default=False):
+             fine_tuner.main()
+
+         if Confirm.ask("Generate drafts?", default=False):
+             DraftGenerator(openai, tumblr).main()
+
+
+ if __name__ == "__main__":
+     sys_exit(main())
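
Note: the substantive changes in this hunk are the new `setlocale(LC_ALL, "")` call and `max_retries=maxsize` on the OpenAI client; passing the clients positionally to the flow classes is a cosmetic cleanup. A minimal sketch of what the two new calls do, assuming the 1.x `openai` package (the API key string is a placeholder):

```python
# Sketch of the two 1.10.0 additions, not the package's full startup;
# the API key below is a placeholder.
from locale import LC_ALL, setlocale
from sys import maxsize

from openai import OpenAI

# An empty locale string adopts the user's environment (LANG/LC_*) instead
# of the default "C" locale, enabling localized number formatting downstream.
setlocale(LC_ALL, "")

# The OpenAI client retries retryable failures (rate limits, timeouts,
# connection errors) with exponential backoff; sys.maxsize in effect
# means "never give up".
client = OpenAI(api_key="sk-placeholder", max_retries=maxsize)
```

Handing retries to the client itself would explain why the tenacity-based retry wrapper disappears from the examples module in the last hunk below.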
tumblrbot/flow/download.py CHANGED
@@ -1,55 +1,55 @@
- from json import dump
- from typing import TYPE_CHECKING, override
-
- from tumblrbot.utils.common import FlowClass, PreviewLive
- from tumblrbot.utils.models import Post
-
- if TYPE_CHECKING:
-     from io import TextIOBase
-
-
- class PostDownloader(FlowClass):
-     @override
-     def main(self) -> None:
-         self.config.data_directory.mkdir(parents=True, exist_ok=True)
-
-         with PreviewLive() as live:
-             for blog_identifier in self.config.download_blog_identifiers:
-                 data_path = self.get_data_path(blog_identifier)
-
-                 completed = 0
-                 after = 0
-                 if data_path.exists():
-                     lines = data_path.read_bytes().splitlines() if data_path.exists() else []
-                     completed = len(lines)
-                     if lines:
-                         after = Post.model_validate_json(lines[-1]).timestamp
-
-                 with data_path.open("a", encoding="utf_8") as fp:
-                     self.paginate_posts(
-                         blog_identifier,
-                         completed,
-                         after,
-                         fp,
-                         live,
-                     )
-
-     def paginate_posts(self, blog_identifier: str, completed: int, after: int, fp: TextIOBase, live: PreviewLive) -> None:
-         task_id = live.progress.add_task(f"Downloading posts from '{blog_identifier}'...", total=None, completed=completed)
-
-         while True:
-             response = self.tumblr.retrieve_published_posts(blog_identifier, after=after)
-             live.progress.update(task_id, total=response.response.blog.posts, completed=completed)
-
-             if not response.response.posts:
-                 return
-
-             for post in response.response.posts:
-                 dump(post, fp)
-                 fp.write("\n")
-
-                 model = Post.model_validate(post)
-                 after = model.timestamp
-                 live.custom_update(model)
-
-             completed += len(response.response.posts)
+ from json import dump
+ from typing import TYPE_CHECKING, override
+
+ from tumblrbot.utils.common import FlowClass, PreviewLive
+ from tumblrbot.utils.models import Post
+
+ if TYPE_CHECKING:
+     from io import TextIOBase
+
+
+ class PostDownloader(FlowClass):
+     @override
+     def main(self) -> None:
+         self.config.data_directory.mkdir(parents=True, exist_ok=True)
+
+         with PreviewLive() as live:
+             for blog_identifier in self.config.download_blog_identifiers:
+                 data_path = self.get_data_path(blog_identifier)
+
+                 completed = 0
+                 after = 0
+                 if data_path.exists():
+                     lines = data_path.read_bytes().splitlines() if data_path.exists() else []
+                     completed = len(lines)
+                     if lines:
+                         after = Post.model_validate_json(lines[-1]).timestamp
+
+                 with data_path.open("a", encoding="utf_8") as fp:
+                     self.paginate_posts(
+                         blog_identifier,
+                         completed,
+                         after,
+                         fp,
+                         live,
+                     )
+
+     def paginate_posts(self, blog_identifier: str, completed: int, after: int, fp: TextIOBase, live: PreviewLive) -> None:
+         task_id = live.progress.add_task(f"Downloading posts from '{blog_identifier}'...", total=None, completed=completed)
+
+         while True:
+             response = self.tumblr.retrieve_published_posts(blog_identifier, after=after)
+             live.progress.update(task_id, total=response.response.blog.posts, completed=completed)
+
+             if not response.response.posts:
+                 return
+
+             for post in response.response.posts:
+                 dump(post, fp)
+                 fp.write("\n")
+
+                 model = Post.model_validate(post)
+                 after = model.timestamp
+                 live.custom_update(model)
+
+             completed += len(response.response.posts)
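
Both sides of this file read identically in the registry view; the logic worth noting is the resume behavior: `completed` and `after` are rebuilt from the existing JSONL file so a re-run continues where the last download stopped. A self-contained sketch of that bookkeeping, using plain `json` in place of the package's pydantic `Post` model (the `timestamp` key mirrors the field the real code reads):

```python
# Hedged sketch of PostDownloader's resume bookkeeping; uses plain json
# rather than Post.model_validate_json to stay self-contained.
from json import loads
from pathlib import Path


def resume_point(data_path: Path) -> tuple[int, int]:
    """Return (posts_already_downloaded, timestamp_to_resume_after)."""
    if not data_path.exists():
        return 0, 0
    lines = data_path.read_bytes().splitlines()
    if not lines:
        return 0, 0
    # The newest post is the last JSONL line; its timestamp seeds `after`.
    return len(lines), loads(lines[-1])["timestamp"]
```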
tumblrbot/flow/examples.py CHANGED
@@ -1,97 +1,94 @@
- from collections.abc import Generator
- from itertools import batched
- from json import loads
- from math import ceil
- from re import IGNORECASE
- from re import compile as re_compile
- from typing import TYPE_CHECKING, override
-
- from openai import RateLimitError
- from rich import print as rich_print
- from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential
-
- from tumblrbot.utils.common import FlowClass, PreviewLive
- from tumblrbot.utils.models import Example, Message, Post
-
- if TYPE_CHECKING:
-     from collections.abc import Generator, Iterable
-     from pathlib import Path
-
-     from openai._types import SequenceNotStr
-     from openai.types import ModerationCreateResponse, ModerationMultiModalInputParam
-
-
- class ExamplesWriter(FlowClass):
-     @override
-     def main(self) -> None:
-         self.config.examples_file.parent.mkdir(parents=True, exist_ok=True)
-
-         examples = [self.create_example(*prompt) for prompt in self.get_custom_prompts()]
-         examples.extend(self.create_example(self.config.user_message, str(post)) for post in self.get_valid_posts())
-         self.write_examples(examples)
-
-         rich_print(f"[bold]The examples file can be found at: '{self.config.examples_file}'\n")
-
-     def create_example(self, user_message: str, assistant_message: str) -> Example:
-         return Example(
-             messages=[
-                 Message(role="developer", content=self.config.developer_message),
-                 Message(role="user", content=user_message),
-                 Message(role="assistant", content=assistant_message),
-             ],
-         )
-
-     def get_custom_prompts(self) -> Generator[tuple[str, str]]:
-         self.config.custom_prompts_file.parent.mkdir(parents=True, exist_ok=True)
-         self.config.custom_prompts_file.touch(exist_ok=True)
-
-         with self.config.custom_prompts_file.open("rb") as fp:
-             for line in fp:
-                 data: dict[str, str] = loads(line)
-                 yield from data.items()
-
-     # This function mostly exists to make writing examples atomic.
-     def write_examples(self, examples: Iterable[Example]) -> None:
-         with self.config.examples_file.open("w", encoding="utf_8") as fp:
-             for example in examples:
-                 fp.write(f"{example.model_dump_json()}\n")
-
-     def get_valid_posts(self) -> Generator[Post]:
-         for path in self.get_data_paths():
-             posts = list(self.get_valid_posts_from_path(path))
-             yield from posts[-self.config.post_limit :]
-
-     def get_valid_posts_from_path(self, path: Path) -> Generator[Post]:
-         pattern = re_compile("|".join(self.config.filtered_words), IGNORECASE)
-         with path.open("rb") as fp:
-             for line in fp:
-                 post = Post.model_validate_json(line)
-                 if post.valid_text_post() and not (post.trail and self.config.filtered_words and pattern.search(str(post))):
-                     yield post
-
-     def filter_examples(self) -> None:
-         raw_examples = self.config.examples_file.read_bytes().splitlines()
-         old_examples = map(Example.model_validate_json, raw_examples)
-         new_examples: list[Example] = []
-         with PreviewLive() as live:
-             for batch in live.progress.track(
-                 batched(old_examples, self.config.moderation_batch_size, strict=False),
-                 ceil(len(raw_examples) / self.config.moderation_batch_size),
-                 description="Removing flagged posts...",
-             ):
-                 response = self.create_moderation_batch(tuple(map(Example.get_assistant_message, batch)))
-                 new_examples.extend(example for example, moderation in zip(batch, response.results, strict=True) if not moderation.flagged)
-
-         self.write_examples(new_examples)
-
-         rich_print(f"[red]Removed {len(raw_examples) - len(new_examples)} posts.\n")
-
-     @retry(
-         stop=stop_after_attempt(10),
-         wait=wait_random_exponential(),
-         retry=retry_if_exception_type(RateLimitError),
-         before_sleep=lambda state: rich_print(f"[yellow]OpenAI rate limit exceeded. Waiting for {state.idle_for} seconds..."),
-         reraise=True,
-     )
-     def create_moderation_batch(self, api_input: str | SequenceNotStr[str] | Iterable[ModerationMultiModalInputParam]) -> ModerationCreateResponse:
-         return self.openai.moderations.create(input=api_input)
+ from collections.abc import Generator
+ from itertools import batched
+ from json import loads
+ from math import ceil
+ from re import IGNORECASE
+ from re import compile as re_compile
+ from typing import TYPE_CHECKING, override
+
+ from rich import print as rich_print
+ from rich.console import Console
+
+ from tumblrbot.utils.common import FlowClass, PreviewLive, localize_number
+ from tumblrbot.utils.models import Example, Message, Post
+
+ if TYPE_CHECKING:
+     from collections.abc import Generator, Iterable
+     from pathlib import Path
+
+     from openai._types import SequenceNotStr
+     from openai.types import ModerationCreateResponse, ModerationMultiModalInputParam
+
+
+ class ExamplesWriter(FlowClass):
+     @override
+     def main(self) -> None:
+         self.config.examples_file.parent.mkdir(parents=True, exist_ok=True)
+
+         examples = [self.create_example(*prompt) for prompt in self.get_custom_prompts()]
+         examples.extend(self.create_example(self.config.user_message, str(post)) for post in self.get_valid_posts())
+         self.write_examples(examples)
+
+         rich_print(f"[bold]The examples file can be found at: '{self.config.examples_file}'\n")
+
+     def create_example(self, user_message: str, assistant_message: str) -> Example:
+         return Example(
+             messages=[
+                 Message(role="developer", content=self.config.developer_message),
+                 Message(role="user", content=user_message),
+                 Message(role="assistant", content=assistant_message),
+             ],
+         )
+
+     def get_custom_prompts(self) -> Generator[tuple[str, str]]:
+         self.config.custom_prompts_file.parent.mkdir(parents=True, exist_ok=True)
+         self.config.custom_prompts_file.touch(exist_ok=True)
+
+         with self.config.custom_prompts_file.open("rb") as fp:
+             for line in fp:
+                 data: dict[str, str] = loads(line)
+                 yield from data.items()
+
+     # This function mostly exists to make writing examples (mostly) atomic.
+     # If there is an error dumping the models or writing to the file,
+     # then it will leave a partially written or empty file behind.
+     def write_examples(self, examples: Iterable[Example]) -> None:
+         with self.config.examples_file.open("w", encoding="utf_8") as fp:
+             for example in examples:
+                 fp.write(f"{example.model_dump_json()}\n")
+
+     def get_valid_posts(self) -> Generator[Post]:
+         for path in self.get_data_paths():
+             if path.exists():
+                 posts = list(self.get_valid_posts_from_path(path))
+                 yield from posts[-self.config.post_limit :]
+             else:
+                 Console(stderr=True, style="logging.level.warning").print(f"{path} does not exist!")
+
+     def get_valid_posts_from_path(self, path: Path) -> Generator[Post]:
+         pattern = re_compile("|".join(self.config.filtered_words), IGNORECASE)
+         with path.open("rb") as fp:
+             for line in fp:
+                 post = Post.model_validate_json(line)
+                 if post.valid_text_post() and not (post.trail and self.config.filtered_words and pattern.search(str(post))):
+                     yield post
+
+     def filter_examples(self) -> None:
+         raw_examples = self.config.examples_file.read_bytes().splitlines()
+         old_examples = map(Example.model_validate_json, raw_examples)
+         new_examples: list[Example] = []
+         with PreviewLive() as live:
+             for batch in live.progress.track(
+                 batched(old_examples, self.config.moderation_batch_size, strict=False),
+                 ceil(len(raw_examples) / self.config.moderation_batch_size),
+                 description="Removing flagged posts...",
+             ):
+                 response = self.create_moderation_batch(tuple(map(Example.get_assistant_message, batch)))
+                 new_examples.extend(example for example, moderation in zip(batch, response.results, strict=True) if not moderation.flagged)
+
+         self.write_examples(new_examples)
+
+         rich_print(f"[red]Removed {localize_number(len(raw_examples) - len(new_examples))} posts.\n")
+
+     def create_moderation_batch(self, api_input: str | SequenceNotStr[str] | Iterable[ModerationMultiModalInputParam]) -> ModerationCreateResponse:
+         return self.openai.moderations.create(input=api_input)
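
Three things change in this final hunk: `get_valid_posts` now warns on stderr about a missing data file instead of raising `FileNotFoundError`, the tenacity `@retry` wrapper around `create_moderation_batch` is gone (consistent with retries now being configured on the client in `__main__.py`), and the removed-post count is formatted with a new `localize_number` helper. That helper's body is not part of this diff; a hypothetical sketch of what such a function could look like, given the `setlocale(LC_ALL, "")` call added at startup:

```python
# Hypothetical localize_number; the real helper lives in
# tumblrbot.utils.common and its body is not shown in this diff.
from locale import LC_ALL, format_string, setlocale

setlocale(LC_ALL, "")  # without this, grouping falls back to the "C" locale


def localize_number(value: int) -> str:
    # grouping=True inserts the locale's thousands separators,
    # e.g. 1234567 -> "1,234,567" under an en_US locale.
    return format_string("%d", value, grouping=True)


print(localize_number(1234567))
```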