diversify-text 0.1.2__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {diversify_text-0.1.2 → diversify_text-0.2.0}/PKG-INFO +28 -13
- {diversify_text-0.1.2 → diversify_text-0.2.0}/README.md +25 -10
- {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/__init__.py +2 -0
- diversify_text-0.2.0/diversify_text/_cache.py +281 -0
- {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/_output.py +7 -7
- {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/_postprocess.py +6 -6
- diversify_text-0.2.0/diversify_text/_utils.py +65 -0
- {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/core.py +126 -53
- {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/filter/mis.py +30 -12
- {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/method/base.py +2 -2
- {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/method/echo.py +2 -2
- diversify_text-0.2.0/diversify_text/method/prompting/__init__.py +6 -0
- diversify_text-0.2.0/diversify_text/method/prompting/method.py +413 -0
- diversify_text-0.2.0/diversify_text/method/prompting/model.py +205 -0
- diversify_text-0.2.0/diversify_text/method/prompting/prompts.py +258 -0
- {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/method/registry.py +2 -0
- {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/method/tinystyler/__init__.py +1 -1
- {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/method/tinystyler/method.py +27 -45
- {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/method/tinystyler/model.py +14 -14
- {diversify_text-0.1.2/diversify_text/method/tinystyler → diversify_text-0.2.0/diversify_text}/styles.py +46 -8
- {diversify_text-0.1.2 → diversify_text-0.2.0}/pyproject.toml +4 -2
- diversify_text-0.1.2/diversify_text/_utils.py +0 -27
- {diversify_text-0.1.2 → diversify_text-0.2.0}/.gitignore +0 -0
- {diversify_text-0.1.2 → diversify_text-0.2.0}/LICENSE +0 -0
- {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/_input.py +0 -0
- {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/_preprocess.py +0 -0
- {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/filter/__init__.py +0 -0
- {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/method/__init__.py +0 -0
- {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/py.typed +0 -0
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: diversify-text
|
|
3
|
-
Version: 0.1.2
|
|
3
|
+
Version: 0.2.0
|
|
4
4
|
Summary: Generate stylistic paraphrases of texts using local transformer models.
|
|
5
5
|
Project-URL: Homepage, https://github.com/AnnaWegmann/diversify_text
|
|
6
6
|
Project-URL: Documentation, https://annawegmann.github.io/diversify_text/
|
|
7
7
|
Project-URL: Repository, https://github.com/AnnaWegmann/diversify_text
|
|
8
8
|
Project-URL: Issues, https://github.com/AnnaWegmann/diversify_text/issues
|
|
9
|
-
Author: Anna Wegmann
|
|
9
|
+
Author: Anna Wegmann, Eduardo Calò, Menan Velayuthan
|
|
10
10
|
License-Expression: MIT
|
|
11
11
|
License-File: LICENSE
|
|
12
12
|
Keywords: augmentation,nlp,paraphrase,style-transfer,text-generation
|
|
@@ -26,7 +26,7 @@ Requires-Dist: pysbd>=0.3.4
|
|
|
26
26
|
Requires-Dist: sentence-transformers
|
|
27
27
|
Requires-Dist: sentencepiece
|
|
28
28
|
Requires-Dist: tiktoken
|
|
29
|
-
Requires-Dist: torch
|
|
29
|
+
Requires-Dist: torch>=2.10.0
|
|
30
30
|
Requires-Dist: tqdm>=4.67.3
|
|
31
31
|
Requires-Dist: transformers>=5.3.0
|
|
32
32
|
Description-Content-Type: text/markdown
|
|
@@ -46,6 +46,7 @@ pip install diversify-text
|
|
|
46
46
|
- [Usage](#usage)
|
|
47
47
|
- [Single text](#single-text)
|
|
48
48
|
- [Control number of paraphrases](#control-number-of-paraphrases)
|
|
49
|
+
- [Caching](#caching)
|
|
49
50
|
- [Using the class directly](#using-the-class-directly)
|
|
50
51
|
- [List of texts](#list-of-texts)
|
|
51
52
|
- [Customising the TinyStyler style bank](#customising-the-tinystyler-style-bank)
|
|
@@ -84,24 +85,38 @@ results = diversify("The experiment was conducted in a controlled lab setting.")
|
|
|
84
85
|
### Control number of paraphrases
|
|
85
86
|
|
|
86
87
|
```python
|
|
87
|
-
results = diversify("Some text.",
|
|
88
|
+
results = diversify("Some text.", n=3)
|
|
88
89
|
```
|
|
89
90
|
|
|
90
91
|
```
|
|
91
92
|
[{"original": "Some text.", "paraphrases": ["...", "...", "..."]}]
|
|
92
93
|
```
|
|
93
94
|
|
|
95
|
+
### Caching
|
|
96
|
+
|
|
97
|
+
The `diversify()` function automatically caches loaded models between calls.
|
|
98
|
+
The generation model and the semantic filter are cached independently, so
|
|
99
|
+
toggling `semantic_filter` does not reload the generation model and vice
|
|
100
|
+
versa. Call `clear_cache()` to drop cached models and allow memory to be reclaimed when possible:
|
|
101
|
+
|
|
102
|
+
```python
|
|
103
|
+
from diversify_text import clear_cache
|
|
104
|
+
|
|
105
|
+
clear_cache()
|
|
106
|
+
```
|
|
107
|
+
|
|
94
108
|
### Using the class directly
|
|
95
109
|
|
|
96
|
-
|
|
110
|
+
You can also instantiate a `Diversifier` yourself for full control over the
|
|
111
|
+
model lifecycle:
|
|
97
112
|
|
|
98
113
|
```python
|
|
99
114
|
from diversify_text import Diversifier
|
|
100
115
|
|
|
101
116
|
div = Diversifier(device="cuda", methods=["tinystyler"])
|
|
102
117
|
|
|
103
|
-
batch_1 = div.diversify(texts_1,
|
|
104
|
-
batch_2 = div.diversify(texts_2,
|
|
118
|
+
batch_1 = div.diversify(texts_1, n=5)
|
|
119
|
+
batch_2 = div.diversify(texts_2, n=5)
|
|
105
120
|
```
|
|
106
121
|
|
|
107
122
|
### List of texts
|
|
@@ -130,7 +145,7 @@ A style bank can be a `dict[str, list[str]]` or a `list[list[str]]`:
|
|
|
130
145
|
|
|
131
146
|
```python
|
|
132
147
|
from diversify_text import diversify
|
|
133
|
-
from diversify_text.
|
|
148
|
+
from diversify_text.styles import DEFAULT_STYLE_BANK
|
|
134
149
|
|
|
135
150
|
custom_bank = {
|
|
136
151
|
"academic": ["The results demonstrate a statistically significant effect."],
|
|
@@ -144,10 +159,10 @@ results = diversify(
|
|
|
144
159
|
)
|
|
145
160
|
```
|
|
146
161
|
|
|
147
|
-
`DEFAULT_STYLE_BANK` is exported from `diversify_text.
|
|
162
|
+
`DEFAULT_STYLE_BANK` is exported from `diversify_text.styles` so you can build on it:
|
|
148
163
|
|
|
149
164
|
```python
|
|
150
|
-
from diversify_text.
|
|
165
|
+
from diversify_text.styles import DEFAULT_STYLE_BANK
|
|
151
166
|
|
|
152
167
|
extended_bank = {
|
|
153
168
|
**DEFAULT_STYLE_BANK,
|
|
@@ -175,11 +190,11 @@ from diversify_text.method import DiversificationMethod
|
|
|
175
190
|
class MyMethod(DiversificationMethod):
|
|
176
191
|
name = "my_method"
|
|
177
192
|
|
|
178
|
-
def generate(self, texts, *,
|
|
179
|
-
return [[f"{text} :: variant {i}" for i in range(
|
|
193
|
+
def generate(self, texts, *, n, max_new_tokens, temperature, top_p, **kwargs):
|
|
194
|
+
return [[f"{text} :: variant {i}" for i in range(n)] for text in texts]
|
|
180
195
|
|
|
181
196
|
|
|
182
|
-
results = Diversifier(methods=[MyMethod()]).diversify("Hello",
|
|
197
|
+
results = Diversifier(methods=[MyMethod()]).diversify("Hello", n=3)
|
|
183
198
|
```
|
|
184
199
|
|
|
185
200
|
```
|
|
@@ -13,6 +13,7 @@ pip install diversify-text
|
|
|
13
13
|
- [Usage](#usage)
|
|
14
14
|
- [Single text](#single-text)
|
|
15
15
|
- [Control number of paraphrases](#control-number-of-paraphrases)
|
|
16
|
+
- [Caching](#caching)
|
|
16
17
|
- [Using the class directly](#using-the-class-directly)
|
|
17
18
|
- [List of texts](#list-of-texts)
|
|
18
19
|
- [Customising the TinyStyler style bank](#customising-the-tinystyler-style-bank)
|
|
@@ -51,24 +52,38 @@ results = diversify("The experiment was conducted in a controlled lab setting.")
|
|
|
51
52
|
### Control number of paraphrases
|
|
52
53
|
|
|
53
54
|
```python
|
|
54
|
-
results = diversify("Some text.",
|
|
55
|
+
results = diversify("Some text.", n=3)
|
|
55
56
|
```
|
|
56
57
|
|
|
57
58
|
```
|
|
58
59
|
[{"original": "Some text.", "paraphrases": ["...", "...", "..."]}]
|
|
59
60
|
```
|
|
60
61
|
|
|
62
|
+
### Caching
|
|
63
|
+
|
|
64
|
+
The `diversify()` function automatically caches loaded models between calls.
|
|
65
|
+
The generation model and the semantic filter are cached independently, so
|
|
66
|
+
toggling `semantic_filter` does not reload the generation model and vice
|
|
67
|
+
versa. Call `clear_cache()` to drop cached models and allow memory to be reclaimed when possible:
|
|
68
|
+
|
|
69
|
+
```python
|
|
70
|
+
from diversify_text import clear_cache
|
|
71
|
+
|
|
72
|
+
clear_cache()
|
|
73
|
+
```
|
|
74
|
+
|
|
61
75
|
### Using the class directly
|
|
62
76
|
|
|
63
|
-
|
|
77
|
+
You can also instantiate a `Diversifier` yourself for full control over the
|
|
78
|
+
model lifecycle:
|
|
64
79
|
|
|
65
80
|
```python
|
|
66
81
|
from diversify_text import Diversifier
|
|
67
82
|
|
|
68
83
|
div = Diversifier(device="cuda", methods=["tinystyler"])
|
|
69
84
|
|
|
70
|
-
batch_1 = div.diversify(texts_1,
|
|
71
|
-
batch_2 = div.diversify(texts_2,
|
|
85
|
+
batch_1 = div.diversify(texts_1, n=5)
|
|
86
|
+
batch_2 = div.diversify(texts_2, n=5)
|
|
72
87
|
```
|
|
73
88
|
|
|
74
89
|
### List of texts
|
|
@@ -97,7 +112,7 @@ A style bank can be a `dict[str, list[str]]` or a `list[list[str]]`:
|
|
|
97
112
|
|
|
98
113
|
```python
|
|
99
114
|
from diversify_text import diversify
|
|
100
|
-
from diversify_text.
|
|
115
|
+
from diversify_text.styles import DEFAULT_STYLE_BANK
|
|
101
116
|
|
|
102
117
|
custom_bank = {
|
|
103
118
|
"academic": ["The results demonstrate a statistically significant effect."],
|
|
@@ -111,10 +126,10 @@ results = diversify(
|
|
|
111
126
|
)
|
|
112
127
|
```
|
|
113
128
|
|
|
114
|
-
`DEFAULT_STYLE_BANK` is exported from `diversify_text.
|
|
129
|
+
`DEFAULT_STYLE_BANK` is exported from `diversify_text.styles` so you can build on it:
|
|
115
130
|
|
|
116
131
|
```python
|
|
117
|
-
from diversify_text.
|
|
132
|
+
from diversify_text.styles import DEFAULT_STYLE_BANK
|
|
118
133
|
|
|
119
134
|
extended_bank = {
|
|
120
135
|
**DEFAULT_STYLE_BANK,
|
|
@@ -142,11 +157,11 @@ from diversify_text.method import DiversificationMethod
|
|
|
142
157
|
class MyMethod(DiversificationMethod):
|
|
143
158
|
name = "my_method"
|
|
144
159
|
|
|
145
|
-
def generate(self, texts, *,
|
|
146
|
-
return [[f"{text} :: variant {i}" for i in range(
|
|
160
|
+
def generate(self, texts, *, n, max_new_tokens, temperature, top_p, **kwargs):
|
|
161
|
+
return [[f"{text} :: variant {i}" for i in range(n)] for text in texts]
|
|
147
162
|
|
|
148
163
|
|
|
149
|
-
results = Diversifier(methods=[MyMethod()]).diversify("Hello",
|
|
164
|
+
results = Diversifier(methods=[MyMethod()]).diversify("Hello", n=3)
|
|
150
165
|
```
|
|
151
166
|
|
|
152
167
|
```
|
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
4
|
|
|
5
|
+
from diversify_text._cache import clear_cache
|
|
5
6
|
from diversify_text.core import (
|
|
6
7
|
Diversifier,
|
|
7
8
|
diversify,
|
|
@@ -9,6 +10,7 @@ from diversify_text.core import (
|
|
|
9
10
|
|
|
10
11
|
__all__ = [
|
|
11
12
|
"Diversifier",
|
|
13
|
+
"clear_cache",
|
|
12
14
|
"diversify",
|
|
13
15
|
]
|
|
14
16
|
|
|
@@ -0,0 +1,281 @@
|
|
|
1
|
+
"""Per-model caching for the :func:`~diversify_text.core.diversify` convenience function.
|
|
2
|
+
|
|
3
|
+
Keeps the generation method(s) and the MIS filter in independent
|
|
4
|
+
module-level caches so that toggling ``semantic_filter`` does not
|
|
5
|
+
reload the generation model, and switching methods does not reload the
|
|
6
|
+
MIS model.
|
|
7
|
+
|
|
8
|
+
Each generation method is cached individually so that adding, removing,
|
|
9
|
+
or reordering methods only (re)loads the ones whose configuration
|
|
10
|
+
actually changed.
|
|
11
|
+
|
|
12
|
+
Not thread-safe. Intended for single-threaded use in scripts and
|
|
13
|
+
notebooks. For multi-threaded applications, use :class:`Diversifier`
|
|
14
|
+
directly with your own instance management.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import inspect
|
|
20
|
+
from collections.abc import Mapping, Sequence
|
|
21
|
+
from functools import lru_cache
|
|
22
|
+
from typing import Any
|
|
23
|
+
|
|
24
|
+
from diversify_text._utils import default_device
|
|
25
|
+
from diversify_text.filter.mis import MISFilter, _DEFAULT_MIN_SCORE, _DEFAULT_N_CANDIDATES
|
|
26
|
+
from diversify_text.method import DEFAULT_METHOD_REGISTRY, DiversificationMethod
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Constructor kwargs that decide *which* model gets loaded; a change in any
# of them must produce a different cache key. Per-call kwargs (styles,
# prompts, n_style_examples, etc.) are deliberately excluded.
# A frozenset keeps this module-level constant immutable.
_CACHE_KWARGS = frozenset({"model", "device", "precision"})
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# ------------------------------------------------------------------
|
|
35
|
+
# Generation method cache (dict-based, one entry per method)
|
|
36
|
+
# ------------------------------------------------------------------
|
|
37
|
+
|
|
38
|
+
# Loaded generation methods keyed by ``(method_name, sorted constructor
# kwargs)`` as built by ``_single_METHOD_CACHE_key``; filled lazily by
# ``get_methods`` and reset by ``clear_cache``.
_METHOD_CACHE: dict[tuple, DiversificationMethod] = dict()
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _resolve_cache_kwargs(
    method_name: str,
    device: str,
    method_kwargs: Mapping[str, dict[str, Any]] | None = None,
) -> dict[str, Any]:
    """Collect the constructor kwargs that decide which model a method loads.

    Only names listed in :data:`_CACHE_KWARGS` are considered. The result is
    built in three layers, later layers winning:

    1. ``device`` -- always present, taken from the *device* argument.
    2. Constructor defaults of the registered method class, read via
       ``inspect.signature``. Parameters without a default are skipped, and
       ``device`` is never overwritten at this stage.
    3. Explicit values from ``method_kwargs[method_name]``, if that entry
       exists (this may override ``device`` too).

    Filling in the signature defaults means that omitting a kwarg and
    explicitly passing its default value resolve to the same dict, so both
    map to one cache entry instead of loading the same model twice. For
    example, if a method's constructor declares ``model="some/model"``, then
    ``method_kwargs=None`` and ``method_kwargs={name: {"model": "some/model"}}``
    resolve identically.

    Parameters
    ----------
    method_name : str
        Registry name of the method (e.g. ``"tinystyler"``).
    device : str
        Torch device string; callers pass an already-resolved value.
    method_kwargs : mapping, optional
        Per-method keyword arguments keyed by method name. Only the entry for
        *method_name* is read, and non-cache kwargs in it are ignored.

    Returns
    -------
    dict[str, Any]
        The cache-relevant kwargs, e.g. ``{"device": "cpu", "model": ...}``.
    """
    resolved: dict[str, Any] = {"device": device}

    # Layer 2: defaults declared on the method class's constructor.
    params = inspect.signature(DEFAULT_METHOD_REGISTRY.get(method_name)).parameters
    signature_defaults = {
        name: param.default
        for name, param in params.items()
        if name in _CACHE_KWARGS
        and name not in resolved
        and param.default is not inspect.Parameter.empty
    }
    resolved.update(signature_defaults)

    # Layer 3: explicit caller-provided values take precedence.
    if method_kwargs and method_name in method_kwargs:
        resolved.update(
            {
                key: value
                for key, value in method_kwargs[method_name].items()
                if key in _CACHE_KWARGS
            }
        )

    return resolved
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _single_METHOD_CACHE_key(
    method_name: str,
    device: str,
    method_kwargs: Mapping[str, dict[str, Any]] | None = None,
) -> tuple:
    """Build a hashable key for a single generation method.

    The key identifies a loaded model instance. It includes only the
    constructor-level kwargs listed in ``_CACHE_KWARGS`` (``model``,
    ``device``, ``precision``), which determine *which* model gets loaded.
    Per-call kwargs like ``styles`` or ``prompts`` are excluded, so changing
    them reuses the same cache entry instead of triggering a reload.

    Constructor defaults are filled in by :func:`_resolve_cache_kwargs`.
    Explicitly passing a default value therefore produces the same key as
    omitting it.

    Parameters
    ----------
    method_name : str
        Registry name of the method (e.g. ``"tinystyler"``).
    device : str
        Torch device string (already resolved, never ``None``).
    method_kwargs : mapping, optional
        Per-method keyword arguments keyed by method name. Only the entry for
        *method_name* is inspected, and only the cache kwargs within that
        entry are included in the key.

    Returns
    -------
    tuple
        ``(method_name, constructor_kwargs)``, where ``constructor_kwargs``
        is a tuple of ``(name, value)`` pairs sorted by name, with ``device``
        included among them, e.g.
        ``("prompting", (("device", "cpu"), ("model", "some/model")))``.
        All kwarg values must be hashable for the key to be usable as a
        dict key.
    """
    resolved = _resolve_cache_kwargs(method_name, device, method_kwargs)
    # Sort by kwarg name so the key does not depend on insertion order.
    constructor_kwargs = tuple(sorted(resolved.items()))
    return method_name, constructor_kwargs
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def get_methods(
    device: str | None,
    methods: str | DiversificationMethod | Sequence[str | DiversificationMethod] | None,
    method_kwargs: Mapping[str, dict[str, Any]] | None = None,
) -> list[DiversificationMethod]:
    """Return cached generation methods, resolving only on config change.

    Iterates the requested *methods* and resolves each one individually
    against the module-level ``_METHOD_CACHE`` dict. On a cache miss the
    method is instantiated via ``DEFAULT_METHOD_REGISTRY.resolve``, which may
    be expensive (e.g. loading a model). On a hit the stored instance is
    reused.

    Methods can be given as registry names (strings) or as pre-built
    :class:`DiversificationMethod` instances. Instances are passed through
    as-is and never cached. Both kinds may be mixed, e.g.
    ``methods=["tinystyler", my_custom_method]``. A single name or a single
    instance that is not wrapped in a list is treated as a one-element list.

    Because each method has its own cache entry, adding or removing a method
    only loads the new ones; already-cached methods are reused.

    Parameters
    ----------
    device : str or None
        Torch device. A falsy value resolves to :func:`default_device`.
    methods : str, DiversificationMethod, or a sequence of them, optional
        Method names and/or instances. ``None`` means ``["tinystyler"]``.
    method_kwargs : mapping, optional
        Per-method keyword arguments keyed by method name, e.g.
        ``{"prompting": {"model": "gpt2"}}``. Only the kwargs in
        ``_CACHE_KWARGS`` (``model``, ``device``, ``precision``) contribute to
        the cache key. On a cache miss, however, the *whole* entry is
        forwarded to the registry.

        NOTE(review): if a method's constructor keeps any of the non-key
        kwargs (e.g. ``styles``), a later call that changes only those will
        reuse the instance built with the earlier values. Confirm this
        against the method implementations.

    Returns
    -------
    list[DiversificationMethod]
        Resolved instances, in the same order as *methods*.

    Raises
    ------
    TypeError
        If an element of *methods* is neither a ``str`` nor a
        :class:`DiversificationMethod`.
    ValueError
        If *methods* is empty.
    """
    device = device or default_device()
    if methods is None:
        methods = ["tinystyler"]
    elif isinstance(methods, (str, DiversificationMethod)):
        # Without this, a bare string would be iterated character by
        # character, each character being looked up as a method name.
        methods = [methods]

    result: list[DiversificationMethod] = []
    for method in methods:
        if isinstance(method, DiversificationMethod):
            result.append(method)
        elif isinstance(method, str):
            key = _single_METHOD_CACHE_key(method, device, method_kwargs)
            if key not in _METHOD_CACHE:  # cache miss -> build and store
                resolve_kwargs: dict[str, Any] = {"device": device}
                if method_kwargs and (method in method_kwargs):
                    resolve_kwargs.update(method_kwargs[method])
                _METHOD_CACHE[key] = DEFAULT_METHOD_REGISTRY.resolve(
                    [method], **resolve_kwargs
                )[0]
            result.append(_METHOD_CACHE[key])
        else:
            raise TypeError(
                "method must be str or DiversificationMethod instance."
            )

    if not result:
        raise ValueError("At least one method is required.")
    return result
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
# ------------------------------------------------------------------
|
|
218
|
+
# MIS filter cache (lru_cache for expensive model load, thin wrapper
|
|
219
|
+
# for cheap per-call settings like min_score and n_candidates)
|
|
220
|
+
# ------------------------------------------------------------------
|
|
221
|
+
|
|
222
|
+
@lru_cache(maxsize=1)
def _load_mis_filter(device: str) -> MISFilter:
    """Construct the MIS filter for *device*; this is the expensive step.

    ``lru_cache(maxsize=1)`` keeps only the filter for the most recently
    requested device string. Asking for a different device evicts it and
    constructs a new one. Cheap per-call thresholds (``min_score``,
    ``n_candidates``) are not part of the cache key; they are set afterwards
    by :func:`get_cached_mis_filter`.
    """
    loaded_filter = MISFilter(device=device)
    return loaded_filter
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def get_cached_mis_filter(
    device: str | None,
    **filter_kwargs: Any,
) -> MISFilter:
    """Return the shared MIS filter, reloading the model only on device change.

    The model itself comes from the ``lru_cache``-backed
    :func:`_load_mis_filter`. This wrapper only writes the cheap threshold
    attributes onto that shared instance. Changing ``min_score`` or
    ``n_candidates`` between calls therefore never reloads the model; only a
    different device does.

    Parameters
    ----------
    device : str or None
        Torch device. A falsy value resolves to :func:`default_device`.
    **filter_kwargs
        Optional ``min_score`` and ``n_candidates``. Any key that is not
        supplied is reset to its module default (``_DEFAULT_MIN_SCORE`` /
        ``_DEFAULT_N_CANDIDATES``), so a value from an earlier call never
        lingers on the shared instance. Other keys are ignored.

    Returns
    -------
    MISFilter
        The cached filter instance with its thresholds updated.
    """
    target_device = device or default_device()
    shared_filter = _load_mis_filter(target_device)

    # Overwrite both thresholds on every call (min_score first, then
    # n_candidates), falling back to the defaults when a key is absent.
    for attr_name, fallback in (
        ("min_score", _DEFAULT_MIN_SCORE),
        ("n_candidates", _DEFAULT_N_CANDIDATES),
    ):
        setattr(shared_filter, attr_name, filter_kwargs.get(attr_name, fallback))

    return shared_filter
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
# ------------------------------------------------------------------
|
|
263
|
+
# Cache management
|
|
264
|
+
# ------------------------------------------------------------------
|
|
265
|
+
|
|
266
|
+
def clear_cache() -> None:
    """Drop references to all cached models so their memory can be reclaimed when possible.

    Empties the generation-method dict cache (``_METHOD_CACHE``) and clears
    the ``lru_cache`` backing the MIS filter (``_load_mis_filter``). After
    this call, the next :func:`get_methods` or :func:`get_cached_mis_filter`
    call builds its models from scratch.

    This removes the cache's own Python-level references only. It does not
    guarantee immediate GPU/CPU memory release: other live references, or
    allocator pools that retain reserved memory, can keep memory in use.
    """
    # Empty the dict in place rather than rebinding the module global.
    # Rebinding would leave any other reference to the original dict object
    # (e.g. a ``from ... import _METHOD_CACHE``) still holding every cached
    # model, so the memory could not be reclaimed.
    _METHOD_CACHE.clear()
    _load_mis_filter.cache_clear()
|
|
@@ -128,7 +128,7 @@ class OutputWriter:
|
|
|
128
128
|
def __init__(
|
|
129
129
|
self,
|
|
130
130
|
input_context: InputContext,
|
|
131
|
-
|
|
131
|
+
n: int,
|
|
132
132
|
output_path: Path | None,
|
|
133
133
|
) -> None:
|
|
134
134
|
"""Initialize the writer.
|
|
@@ -137,14 +137,14 @@ class OutputWriter:
|
|
|
137
137
|
----------
|
|
138
138
|
input_context : InputContext
|
|
139
139
|
Metadata about the input source (kind, path, etc.).
|
|
140
|
-
|
|
140
|
+
n : int
|
|
141
141
|
Number of paraphrase styles requested per text.
|
|
142
142
|
output_path : Path or None
|
|
143
143
|
Where to write results on disk. ``None`` means results
|
|
144
144
|
are kept in memory and returned as ``list[dict]``.
|
|
145
145
|
"""
|
|
146
146
|
self._input_context = input_context
|
|
147
|
-
self.
|
|
147
|
+
self._n = n
|
|
148
148
|
self._output_path = output_path
|
|
149
149
|
# Open file handle — set by open() when writing to disk.
|
|
150
150
|
self._handle: IO[str] | None = None
|
|
@@ -181,7 +181,7 @@ class OutputWriter:
|
|
|
181
181
|
originals : list[str]
|
|
182
182
|
The original texts in this batch.
|
|
183
183
|
paraphrases_by_text : list[list[str]]
|
|
184
|
-
One inner list per original text, each containing *
|
|
184
|
+
One inner list per original text, each containing *n*
|
|
185
185
|
paraphrased variants. For example, with 2 styles and 2
|
|
186
186
|
texts: ``[["a_style1", "a_style2"], ["b_style1", "b_style2"]]``.
|
|
187
187
|
Raises
|
|
@@ -197,10 +197,10 @@ class OutputWriter:
|
|
|
197
197
|
)
|
|
198
198
|
|
|
199
199
|
for i, (orig, paras) in enumerate(zip(originals, paraphrases_by_text)):
|
|
200
|
-
if len(paras) != self.
|
|
201
|
-
_log.
|
|
200
|
+
if len(paras) != self._n:
|
|
201
|
+
_log.debug(
|
|
202
202
|
"Expected %d paraphrases for text %d, got %d.",
|
|
203
|
-
self.
|
|
203
|
+
self._n, i, len(paras),
|
|
204
204
|
)
|
|
205
205
|
record = {"original": orig, "paraphrases": paras}
|
|
206
206
|
if self._output_path is None:
|
|
@@ -18,19 +18,19 @@ def reassemble_segments(
|
|
|
18
18
|
:func:`~diversify_text._preprocess.split_sentences`).
|
|
19
19
|
paraphrases_by_segment : list[list[str]]
|
|
20
20
|
Flat list of paraphrases for every segment, shape
|
|
21
|
-
``[total_segments][
|
|
21
|
+
``[total_segments][n]``.
|
|
22
22
|
|
|
23
23
|
Returns
|
|
24
24
|
-------
|
|
25
25
|
list[list[str]]
|
|
26
|
-
Shape ``[n_texts][
|
|
26
|
+
Shape ``[n_texts][n]`` — reassembled paraphrases.
|
|
27
27
|
"""
|
|
28
28
|
result = []
|
|
29
29
|
seg_idx = 0
|
|
30
30
|
for segs in segments_per_text:
|
|
31
31
|
seg_paras = paraphrases_by_segment[seg_idx : seg_idx + len(segs)]
|
|
32
|
-
|
|
33
|
-
result.append([" ".join(sp[i] for sp in seg_paras) for i in range(
|
|
32
|
+
n = len(seg_paras[0])
|
|
33
|
+
result.append([" ".join(sp[i] for sp in seg_paras) for i in range(n)])
|
|
34
34
|
seg_idx += len(segs)
|
|
35
35
|
return result
|
|
36
36
|
|
|
@@ -48,14 +48,14 @@ def postprocess(
|
|
|
48
48
|
Parameters
|
|
49
49
|
----------
|
|
50
50
|
candidate : list[list[str]]
|
|
51
|
-
Raw generation output, shape ``[n_generation_texts][
|
|
51
|
+
Raw generation output, shape ``[n_generation_texts][n]``.
|
|
52
52
|
context : PreprocessContext
|
|
53
53
|
Context returned by :func:`~diversify_text._preprocess.preprocess`.
|
|
54
54
|
|
|
55
55
|
Returns
|
|
56
56
|
-------
|
|
57
57
|
list[list[str]]
|
|
58
|
-
Shape ``[n_texts][
|
|
58
|
+
Shape ``[n_texts][n]`` — one paraphrase per original text
|
|
59
59
|
per style.
|
|
60
60
|
"""
|
|
61
61
|
if context.segments_per_text is not None:
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
"""Shared internal utilities for diversify."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import contextlib
|
|
6
|
+
import itertools
|
|
7
|
+
import logging
|
|
8
|
+
import sys
|
|
9
|
+
import threading
|
|
10
|
+
import warnings
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def default_device() -> str:
    """Pick a torch device name, preferring CUDA, then Apple MPS, then CPU.

    ``torch`` is imported lazily, inside the function.

    Returns
    -------
    str
        ``"cuda"`` if CUDA is available, else ``"mps"`` if the MPS backend is
        available, else ``"cpu"``.
    """
    import torch

    if torch.cuda.is_available():
        return "cuda"
    return "mps" if torch.backends.mps.is_available() else "cpu"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@contextlib.contextmanager
def spinner(message: str = "Loading"):
    """Display a CLI spinner on stderr while a blocking operation runs.

    A background daemon thread redraws an animation frame followed by
    *message* roughly every 80 ms until the ``with`` body finishes. The line
    is then finalized with ``✓`` if the body completed normally, or with
    ``✗`` if it raised. Any exception is re-raised unchanged.

    Parameters
    ----------
    message : str
        Text shown next to the spinner and on the final status line.
    """
    frames = itertools.cycle(["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"])
    stop = threading.Event()
    # Set when the body raises, so the final line does not claim success.
    failed = threading.Event()

    def _spin() -> None:
        while not stop.is_set():
            frame = next(frames)
            sys.stderr.write(f"\r{frame} {message}")
            sys.stderr.flush()
            stop.wait(0.08)
        mark = "✗" if failed.is_set() else "✓"
        sys.stderr.write(f"\r{mark} {message}\n")
        sys.stderr.flush()

    thread = threading.Thread(target=_spin, daemon=True)
    thread.start()
    try:
        yield
    except BaseException:
        failed.set()
        raise
    finally:
        # Stop the animation and wait for the final status line to be written
        # before control returns to the caller.
        stop.set()
        thread.join()
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@contextlib.contextmanager
def suppress_hf_load_noise():
    """Temporarily silence harmless noise emitted while loading HuggingFace models.

    While the ``with`` body runs, two things are in effect:

    - The ``"transformers"`` logger's level is raised to ``ERROR``, which hides
      its lower-severity records (e.g. tied-weights notices). Presumably the
      unexpected-key load reports from the style-embedding model also arrive
      through this logger -- TODO confirm.
    - Python warnings whose message matches ``.*tie.*weight.*`` are ignored.

    On exit, normal or via exception, the warning filters are restored first
    and then the logger's previous level.
    """
    hf_logger = logging.getLogger("transformers")
    saved_level = hf_logger.level
    hf_logger.setLevel(logging.ERROR)
    with contextlib.ExitStack() as stack:
        # Callbacks unwind LIFO: catch_warnings exits first, then the level
        # is restored.
        stack.callback(hf_logger.setLevel, saved_level)
        stack.enter_context(warnings.catch_warnings())
        warnings.filterwarnings("ignore", message=".*tie.*weight.*")
        yield
|