diversify-text 0.1.2__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29) hide show
  1. {diversify_text-0.1.2 → diversify_text-0.2.0}/PKG-INFO +28 -13
  2. {diversify_text-0.1.2 → diversify_text-0.2.0}/README.md +25 -10
  3. {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/__init__.py +2 -0
  4. diversify_text-0.2.0/diversify_text/_cache.py +281 -0
  5. {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/_output.py +7 -7
  6. {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/_postprocess.py +6 -6
  7. diversify_text-0.2.0/diversify_text/_utils.py +65 -0
  8. {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/core.py +126 -53
  9. {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/filter/mis.py +30 -12
  10. {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/method/base.py +2 -2
  11. {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/method/echo.py +2 -2
  12. diversify_text-0.2.0/diversify_text/method/prompting/__init__.py +6 -0
  13. diversify_text-0.2.0/diversify_text/method/prompting/method.py +413 -0
  14. diversify_text-0.2.0/diversify_text/method/prompting/model.py +205 -0
  15. diversify_text-0.2.0/diversify_text/method/prompting/prompts.py +258 -0
  16. {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/method/registry.py +2 -0
  17. {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/method/tinystyler/__init__.py +1 -1
  18. {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/method/tinystyler/method.py +27 -45
  19. {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/method/tinystyler/model.py +14 -14
  20. {diversify_text-0.1.2/diversify_text/method/tinystyler → diversify_text-0.2.0/diversify_text}/styles.py +46 -8
  21. {diversify_text-0.1.2 → diversify_text-0.2.0}/pyproject.toml +4 -2
  22. diversify_text-0.1.2/diversify_text/_utils.py +0 -27
  23. {diversify_text-0.1.2 → diversify_text-0.2.0}/.gitignore +0 -0
  24. {diversify_text-0.1.2 → diversify_text-0.2.0}/LICENSE +0 -0
  25. {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/_input.py +0 -0
  26. {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/_preprocess.py +0 -0
  27. {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/filter/__init__.py +0 -0
  28. {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/method/__init__.py +0 -0
  29. {diversify_text-0.1.2 → diversify_text-0.2.0}/diversify_text/py.typed +0 -0
@@ -1,12 +1,12 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diversify-text
3
- Version: 0.1.2
3
+ Version: 0.2.0
4
4
  Summary: Generate stylistic paraphrases of texts using local transformer models.
5
5
  Project-URL: Homepage, https://github.com/AnnaWegmann/diversify_text
6
6
  Project-URL: Documentation, https://annawegmann.github.io/diversify_text/
7
7
  Project-URL: Repository, https://github.com/AnnaWegmann/diversify_text
8
8
  Project-URL: Issues, https://github.com/AnnaWegmann/diversify_text/issues
9
- Author: Anna Wegmann
9
+ Author: Anna Wegmann, Eduardo Calò, Menan Velayuthan
10
10
  License-Expression: MIT
11
11
  License-File: LICENSE
12
12
  Keywords: augmentation,nlp,paraphrase,style-transfer,text-generation
@@ -26,7 +26,7 @@ Requires-Dist: pysbd>=0.3.4
26
26
  Requires-Dist: sentence-transformers
27
27
  Requires-Dist: sentencepiece
28
28
  Requires-Dist: tiktoken
29
- Requires-Dist: torch
29
+ Requires-Dist: torch>=2.10.0
30
30
  Requires-Dist: tqdm>=4.67.3
31
31
  Requires-Dist: transformers>=5.3.0
32
32
  Description-Content-Type: text/markdown
@@ -46,6 +46,7 @@ pip install diversify-text
46
46
  - [Usage](#usage)
47
47
  - [Single text](#single-text)
48
48
  - [Control number of paraphrases](#control-number-of-paraphrases)
49
+ - [Caching](#caching)
49
50
  - [Using the class directly](#using-the-class-directly)
50
51
  - [List of texts](#list-of-texts)
51
52
  - [Customising the TinyStyler style bank](#customising-the-tinystyler-style-bank)
@@ -84,24 +85,38 @@ results = diversify("The experiment was conducted in a controlled lab setting.")
84
85
  ### Control number of paraphrases
85
86
 
86
87
  ```python
87
- results = diversify("Some text.", n_styles=3)
88
+ results = diversify("Some text.", n=3)
88
89
  ```
89
90
 
90
91
  ```
91
92
  [{"original": "Some text.", "paraphrases": ["...", "...", "..."]}]
92
93
  ```
93
94
 
95
+ ### Caching
96
+
97
+ The `diversify()` function automatically caches loaded models between calls.
98
+ The generation model and the semantic filter are cached independently, so
99
+ toggling `semantic_filter` does not reload the generation model and vice
100
+ versa. Call `clear_cache()` to drop cached models and allow memory to be reclaimed when possible:
101
+
102
+ ```python
103
+ from diversify_text import clear_cache
104
+
105
+ clear_cache()
106
+ ```
107
+
94
108
  ### Using the class directly
95
109
 
96
- Recommended when processing texts across several calls the model is loaded once and reused across calls.
110
+ You can also instantiate a `Diversifier` yourself for full control over the
111
+ model lifecycle:
97
112
 
98
113
  ```python
99
114
  from diversify_text import Diversifier
100
115
 
101
116
  div = Diversifier(device="cuda", methods=["tinystyler"])
102
117
 
103
- batch_1 = div.diversify(texts_1, n_styles=5)
104
- batch_2 = div.diversify(texts_2, n_styles=5)
118
+ batch_1 = div.diversify(texts_1, n=5)
119
+ batch_2 = div.diversify(texts_2, n=5)
105
120
  ```
106
121
 
107
122
  ### List of texts
@@ -130,7 +145,7 @@ A style bank can be a `dict[str, list[str]]` or a `list[list[str]]`:
130
145
 
131
146
  ```python
132
147
  from diversify_text import diversify
133
- from diversify_text.method.tinystyler import DEFAULT_STYLE_BANK
148
+ from diversify_text.styles import DEFAULT_STYLE_BANK
134
149
 
135
150
  custom_bank = {
136
151
  "academic": ["The results demonstrate a statistically significant effect."],
@@ -144,10 +159,10 @@ results = diversify(
144
159
  )
145
160
  ```
146
161
 
147
- `DEFAULT_STYLE_BANK` is exported from `diversify_text.method.tinystyler` so you can build on it:
162
+ `DEFAULT_STYLE_BANK` is exported from `diversify_text.styles` so you can build on it:
148
163
 
149
164
  ```python
150
- from diversify_text.method.tinystyler import DEFAULT_STYLE_BANK
165
+ from diversify_text.styles import DEFAULT_STYLE_BANK
151
166
 
152
167
  extended_bank = {
153
168
  **DEFAULT_STYLE_BANK,
@@ -175,11 +190,11 @@ from diversify_text.method import DiversificationMethod
175
190
  class MyMethod(DiversificationMethod):
176
191
  name = "my_method"
177
192
 
178
- def generate(self, texts, *, n_styles, max_new_tokens, temperature, top_p, **kwargs):
179
- return [[f"{text} :: variant {i}" for i in range(n_styles)] for text in texts]
193
+ def generate(self, texts, *, n, max_new_tokens, temperature, top_p, **kwargs):
194
+ return [[f"{text} :: variant {i}" for i in range(n)] for text in texts]
180
195
 
181
196
 
182
- results = Diversifier(methods=[MyMethod()]).diversify("Hello", n_styles=3)
197
+ results = Diversifier(methods=[MyMethod()]).diversify("Hello", n=3)
183
198
  ```
184
199
 
185
200
  ```
@@ -13,6 +13,7 @@ pip install diversify-text
13
13
  - [Usage](#usage)
14
14
  - [Single text](#single-text)
15
15
  - [Control number of paraphrases](#control-number-of-paraphrases)
16
+ - [Caching](#caching)
16
17
  - [Using the class directly](#using-the-class-directly)
17
18
  - [List of texts](#list-of-texts)
18
19
  - [Customising the TinyStyler style bank](#customising-the-tinystyler-style-bank)
@@ -51,24 +52,38 @@ results = diversify("The experiment was conducted in a controlled lab setting.")
51
52
  ### Control number of paraphrases
52
53
 
53
54
  ```python
54
- results = diversify("Some text.", n_styles=3)
55
+ results = diversify("Some text.", n=3)
55
56
  ```
56
57
 
57
58
  ```
58
59
  [{"original": "Some text.", "paraphrases": ["...", "...", "..."]}]
59
60
  ```
60
61
 
62
+ ### Caching
63
+
64
+ The `diversify()` function automatically caches loaded models between calls.
65
+ The generation model and the semantic filter are cached independently, so
66
+ toggling `semantic_filter` does not reload the generation model and vice
67
+ versa. Call `clear_cache()` to drop cached models and allow memory to be reclaimed when possible:
68
+
69
+ ```python
70
+ from diversify_text import clear_cache
71
+
72
+ clear_cache()
73
+ ```
74
+
61
75
  ### Using the class directly
62
76
 
63
- Recommended when processing texts across several calls the model is loaded once and reused across calls.
77
+ You can also instantiate a `Diversifier` yourself for full control over the
78
+ model lifecycle:
64
79
 
65
80
  ```python
66
81
  from diversify_text import Diversifier
67
82
 
68
83
  div = Diversifier(device="cuda", methods=["tinystyler"])
69
84
 
70
- batch_1 = div.diversify(texts_1, n_styles=5)
71
- batch_2 = div.diversify(texts_2, n_styles=5)
85
+ batch_1 = div.diversify(texts_1, n=5)
86
+ batch_2 = div.diversify(texts_2, n=5)
72
87
  ```
73
88
 
74
89
  ### List of texts
@@ -97,7 +112,7 @@ A style bank can be a `dict[str, list[str]]` or a `list[list[str]]`:
97
112
 
98
113
  ```python
99
114
  from diversify_text import diversify
100
- from diversify_text.method.tinystyler import DEFAULT_STYLE_BANK
115
+ from diversify_text.styles import DEFAULT_STYLE_BANK
101
116
 
102
117
  custom_bank = {
103
118
  "academic": ["The results demonstrate a statistically significant effect."],
@@ -111,10 +126,10 @@ results = diversify(
111
126
  )
112
127
  ```
113
128
 
114
- `DEFAULT_STYLE_BANK` is exported from `diversify_text.method.tinystyler` so you can build on it:
129
+ `DEFAULT_STYLE_BANK` is exported from `diversify_text.styles` so you can build on it:
115
130
 
116
131
  ```python
117
- from diversify_text.method.tinystyler import DEFAULT_STYLE_BANK
132
+ from diversify_text.styles import DEFAULT_STYLE_BANK
118
133
 
119
134
  extended_bank = {
120
135
  **DEFAULT_STYLE_BANK,
@@ -142,11 +157,11 @@ from diversify_text.method import DiversificationMethod
142
157
  class MyMethod(DiversificationMethod):
143
158
  name = "my_method"
144
159
 
145
- def generate(self, texts, *, n_styles, max_new_tokens, temperature, top_p, **kwargs):
146
- return [[f"{text} :: variant {i}" for i in range(n_styles)] for text in texts]
160
+ def generate(self, texts, *, n, max_new_tokens, temperature, top_p, **kwargs):
161
+ return [[f"{text} :: variant {i}" for i in range(n)] for text in texts]
147
162
 
148
163
 
149
- results = Diversifier(methods=[MyMethod()]).diversify("Hello", n_styles=3)
164
+ results = Diversifier(methods=[MyMethod()]).diversify("Hello", n=3)
150
165
  ```
151
166
 
152
167
  ```
@@ -2,6 +2,7 @@
2
2
 
3
3
  import logging
4
4
 
5
+ from diversify_text._cache import clear_cache
5
6
  from diversify_text.core import (
6
7
  Diversifier,
7
8
  diversify,
@@ -9,6 +10,7 @@ from diversify_text.core import (
9
10
 
10
11
  __all__ = [
11
12
  "Diversifier",
13
+ "clear_cache",
12
14
  "diversify",
13
15
  ]
14
16
 
@@ -0,0 +1,281 @@
1
+ """Per-model caching for the :func:`~diversify_text.core.diversify` convenience function.
2
+
3
+ Keeps the generation method(s) and the MIS filter in independent
4
+ module-level caches so that toggling ``semantic_filter`` does not
5
+ reload the generation model, and switching methods does not reload the
6
+ MIS model.
7
+
8
+ Each generation method is cached individually so that adding, removing,
9
+ or reordering methods only (re)loads the ones whose configuration
10
+ actually changed.
11
+
12
+ Not thread-safe. Intended for single-threaded use in scripts and
13
+ notebooks. For multi-threaded applications, use :class:`Diversifier`
14
+ directly with your own instance management.
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import inspect
20
+ from collections.abc import Mapping, Sequence
21
+ from functools import lru_cache
22
+ from typing import Any
23
+
24
+ from diversify_text._utils import default_device
25
+ from diversify_text.filter.mis import MISFilter, _DEFAULT_MIN_SCORE, _DEFAULT_N_CANDIDATES
26
+ from diversify_text.method import DEFAULT_METHOD_REGISTRY, DiversificationMethod
27
+
28
+
29
+ # kwargs that affect model construction and should invalidate the cache.
30
+ # Per-call kwargs (styles, prompts, n_style_examples, etc.) are excluded.
31
+ _CACHE_KWARGS = {"model", "device", "precision"}
32
+
33
+
34
+ # ------------------------------------------------------------------
35
+ # Generation method cache (dict-based, one entry per method)
36
+ # ------------------------------------------------------------------
37
+
38
+ _METHOD_CACHE: dict[tuple, DiversificationMethod] = {}
39
+
40
+
41
def _resolve_cache_kwargs(
    method_name: str,
    device: str,
    method_kwargs: Mapping[str, dict[str, Any]] | None = None,
) -> dict[str, Any]:
    """Collect the constructor kwargs that identify a cached method instance.

    The result combines three sources, in increasing priority:

    1. ``device`` as passed to this function (always present),
    2. defaults declared on the method class constructor, read through
       ``inspect.signature``,
    3. values the caller supplied in ``method_kwargs[method_name]``.

    Only names listed in :data:`_CACHE_KWARGS` are ever included.

    Filling in constructor defaults makes the cache key independent of
    whether the caller spells out a default or leaves it off. If a
    constructor declares ``model="some/default"``, then omitting
    ``model`` and passing ``model="some/default"`` explicitly both
    resolve to the same dict here, so the model is only loaded once.

    Parameters
    ----------
    method_name : str
        Registry name of the method (e.g. ``"tinystyler"``).
    device : str
        Torch device string (already resolved, never ``None``).
    method_kwargs : mapping, optional
        Per-method keyword arguments keyed by method name. Only the
        entry for *method_name* is read.

    Returns
    -------
    dict[str, Any]
        Cache-relevant kwargs, e.g.
        ``{"device": "cpu", "model": "some/default", "precision": "auto"}``.
        A caller-supplied ``device`` entry replaces the *device* argument.
    """
    method_class = DEFAULT_METHOD_REGISTRY.get(method_name)
    constructor_params = inspect.signature(method_class).parameters

    # ``inspect.Parameter.empty`` marks "no default declared"; such
    # parameters contribute nothing. ``device`` is seeded from the
    # argument, so a constructor default for it is never consulted.
    constructor_defaults = {
        name: param.default
        for name, param in constructor_params.items()
        if name in _CACHE_KWARGS
        and name != "device"
        and param.default is not inspect.Parameter.empty
    }
    resolved: dict[str, Any] = {"device": device, **constructor_defaults}

    # Caller-provided values win over the defaults gathered above.
    if method_kwargs and (method_name in method_kwargs):
        caller_kwargs = method_kwargs[method_name]
        resolved.update(
            {key: value for key, value in caller_kwargs.items() if key in _CACHE_KWARGS}
        )

    return resolved
108
+
109
+
110
def _single_METHOD_CACHE_key(
    method_name: str,
    device: str,
    method_kwargs: Mapping[str, dict[str, Any]] | None = None,
) -> tuple:
    """Build the hashable cache key for one generation method.

    The key is the method name paired with the sorted items of the
    dict returned by :func:`_resolve_cache_kwargs`. It therefore
    covers only the kwargs that function keeps (``device``, ``model``,
    ``precision``); per-call kwargs such as ``styles`` or ``prompts``
    never reach the key, so changing them reuses the cached instance.

    Because constructor defaults are resolved before the key is built,
    explicitly passing a default value yields the same key as leaving
    it out.

    Parameters
    ----------
    method_name : str
        Registry name of the method (e.g. ``"tinystyler"``).
    device : str
        Torch device string (already resolved, never ``None``).
    method_kwargs : mapping, optional
        Per-method keyword arguments keyed by method name. Only the
        entry for *method_name* is read.

    Returns
    -------
    tuple
        ``(method_name, sorted_kwargs_items)``, e.g.
        ``("prompting", (("device", "cpu"), ("model", "default-model"),
        ("precision", "auto")))``.
    """
    cache_kwargs = _resolve_cache_kwargs(method_name, device, method_kwargs)
    # Sorting makes the key independent of dict insertion order.
    return (method_name, tuple(sorted(cache_kwargs.items())))
147
+
148
+
149
def get_methods(
    device: str | None,
    methods: Sequence[str | DiversificationMethod] | None,
    method_kwargs: Mapping[str, dict[str, Any]] | None = None,
) -> list[DiversificationMethod]:
    """Resolve *methods* to instances, reusing cached ones where possible.

    Each entry of *methods* is handled on its own:

    - a :class:`DiversificationMethod` instance is returned as-is and
      never stored in the cache;
    - a string is turned into a cache key via
      :func:`_single_METHOD_CACHE_key`. On a miss the method is built
      through ``DEFAULT_METHOD_REGISTRY.resolve`` and stored; on a hit
      the stored instance is returned.

    Both kinds may be mixed in one call, e.g.
    ``methods=["tinystyler", my_custom_method]``. Since entries are
    cached individually, changing the list only builds the methods that
    are not cached yet.

    Parameters
    ----------
    device : str or None
        Torch device. A falsy value resolves to :func:`default_device`.
    methods : sequence of str or DiversificationMethod, optional
        Method names and/or pre-built instances. ``None`` means
        ``["tinystyler"]``.
    method_kwargs : mapping, optional
        Per-method keyword arguments keyed by method name, e.g.
        ``{"prompting": {"model": "gpt2"}}``. On a cache miss the whole
        entry for the method is forwarded to the registry; only
        ``model``, ``device`` and ``precision`` influence the cache key.

    Returns
    -------
    list[DiversificationMethod]
        Resolved instances, in the same order as *methods*.

    Raises
    ------
    TypeError
        If an entry is neither a ``str`` nor a ``DiversificationMethod``.
    ValueError
        If *methods* is empty.
    """
    if not device:
        device = default_device()
    requested = ["tinystyler"] if methods is None else methods

    resolved_methods: list[DiversificationMethod] = []
    for entry in requested:
        if isinstance(entry, DiversificationMethod):
            # Already instantiated by the caller: bypass the cache.
            resolved_methods.append(entry)
            continue
        if not isinstance(entry, str):
            raise TypeError(
                "method must be str or DiversificationMethod instance."
            )

        cache_key = _single_METHOD_CACHE_key(entry, device, method_kwargs)
        if cache_key not in _METHOD_CACHE:
            # Cache miss: build the method (may load a model) and keep it.
            build_kwargs: dict[str, Any] = {"device": device}
            if method_kwargs and (entry in method_kwargs):
                build_kwargs.update(method_kwargs[entry])
            built = DEFAULT_METHOD_REGISTRY.resolve([entry], **build_kwargs)
            _METHOD_CACHE[cache_key] = built[0]
        resolved_methods.append(_METHOD_CACHE[cache_key])

    if not resolved_methods:
        raise ValueError("At least one method is required.")
    return resolved_methods
215
+
216
+
217
+ # ------------------------------------------------------------------
218
+ # MIS filter cache (lru_cache for expensive model load, thin wrapper
219
+ # for cheap per-call settings like min_score and n_candidates)
220
+ # ------------------------------------------------------------------
221
+
222
@lru_cache(maxsize=1)
def _load_mis_filter(device: str) -> MISFilter:
    """Construct the MIS filter for *device* (the expensive step).

    ``lru_cache(maxsize=1)`` keeps exactly one filter: repeated calls
    with the same device string return the same instance, while a call
    with a different device builds a new filter and evicts the previous
    one. Cheap per-call settings (``min_score``, ``n_candidates``) are
    not handled here; :func:`get_cached_mis_filter` applies them to the
    returned instance.
    """
    mis_filter = MISFilter(device=device)
    return mis_filter
232
+
233
+
234
def get_cached_mis_filter(
    device: str | None,
    **filter_kwargs: Any,
) -> MISFilter:
    """Return the cached MIS filter with this call's thresholds applied.

    The filter itself comes from :func:`_load_mis_filter`, which is
    cached per device string, so only a device change causes a new
    filter to be built. ``min_score`` and ``n_candidates`` are then
    written onto that shared instance as plain attributes.

    Parameters
    ----------
    device : str or None
        Torch device. A falsy value resolves to :func:`default_device`.
    **filter_kwargs
        ``min_score`` and ``n_candidates``. A key that is not supplied
        is reset to its module default, so a value set by an earlier
        call never leaks into this one. Other keys are ignored.

    Returns
    -------
    MISFilter
        The cached instance for the resolved device. The same object is
        returned (and mutated) on every call for that device.
    """
    resolved_device = device if device else default_device()
    cached_filter = _load_mis_filter(resolved_device)

    # (attribute name, fallback when the caller did not pass it)
    per_call_settings = (
        ("min_score", _DEFAULT_MIN_SCORE),
        ("n_candidates", _DEFAULT_N_CANDIDATES),
    )
    for attr_name, fallback in per_call_settings:
        setattr(cached_filter, attr_name, filter_kwargs.get(attr_name, fallback))
    return cached_filter
260
+
261
+
262
+ # ------------------------------------------------------------------
263
+ # Cache management
264
+ # ------------------------------------------------------------------
265
+
266
def clear_cache() -> None:
    """Drop references to all cached models so their memory can be reclaimed when possible.

    Empties the generation method dict cache and the ``lru_cache``
    backing the MIS filter. After calling this, the next
    :func:`get_methods` or :func:`get_cached_mis_filter` call will
    load models from scratch.

    This clears the caches' own references only. Objects still held by
    the caller (for example a method list returned earlier) stay alive,
    and clearing does not guarantee immediate GPU/CPU memory release
    (e.g., allocator pools may retain reserved memory).
    """
    # Empty the dict in place instead of rebinding the module global.
    # Rebinding would leave any other reference to the original dict
    # (e.g. obtained via ``from ... import _METHOD_CACHE``) fully
    # populated and still keeping the models alive.
    _METHOD_CACHE.clear()
    _load_mis_filter.cache_clear()
@@ -128,7 +128,7 @@ class OutputWriter:
128
128
  def __init__(
129
129
  self,
130
130
  input_context: InputContext,
131
- n_styles: int,
131
+ n: int,
132
132
  output_path: Path | None,
133
133
  ) -> None:
134
134
  """Initialize the writer.
@@ -137,14 +137,14 @@ class OutputWriter:
137
137
  ----------
138
138
  input_context : InputContext
139
139
  Metadata about the input source (kind, path, etc.).
140
- n_styles : int
140
+ n : int
141
141
  Number of paraphrase styles requested per text.
142
142
  output_path : Path or None
143
143
  Where to write results on disk. ``None`` means results
144
144
  are kept in memory and returned as ``list[dict]``.
145
145
  """
146
146
  self._input_context = input_context
147
- self._n_styles = n_styles
147
+ self._n = n
148
148
  self._output_path = output_path
149
149
  # Open file handle — set by open() when writing to disk.
150
150
  self._handle: IO[str] | None = None
@@ -181,7 +181,7 @@ class OutputWriter:
181
181
  originals : list[str]
182
182
  The original texts in this batch.
183
183
  paraphrases_by_text : list[list[str]]
184
- One inner list per original text, each containing *n_styles*
184
+ One inner list per original text, each containing *n*
185
185
  paraphrased variants. For example, with 2 styles and 2
186
186
  texts: ``[["a_style1", "a_style2"], ["b_style1", "b_style2"]]``.
187
187
  Raises
@@ -197,10 +197,10 @@ class OutputWriter:
197
197
  )
198
198
 
199
199
  for i, (orig, paras) in enumerate(zip(originals, paraphrases_by_text)):
200
- if len(paras) != self._n_styles:
201
- _log.warning(
200
+ if len(paras) != self._n:
201
+ _log.debug(
202
202
  "Expected %d paraphrases for text %d, got %d.",
203
- self._n_styles, i, len(paras),
203
+ self._n, i, len(paras),
204
204
  )
205
205
  record = {"original": orig, "paraphrases": paras}
206
206
  if self._output_path is None:
@@ -18,19 +18,19 @@ def reassemble_segments(
18
18
  :func:`~diversify_text._preprocess.split_sentences`).
19
19
  paraphrases_by_segment : list[list[str]]
20
20
  Flat list of paraphrases for every segment, shape
21
- ``[total_segments][n_styles]``.
21
+ ``[total_segments][n]``.
22
22
 
23
23
  Returns
24
24
  -------
25
25
  list[list[str]]
26
- Shape ``[n_texts][n_styles]`` — reassembled paraphrases.
26
+ Shape ``[n_texts][n]`` — reassembled paraphrases.
27
27
  """
28
28
  result = []
29
29
  seg_idx = 0
30
30
  for segs in segments_per_text:
31
31
  seg_paras = paraphrases_by_segment[seg_idx : seg_idx + len(segs)]
32
- n_styles = len(seg_paras[0])
33
- result.append([" ".join(sp[i] for sp in seg_paras) for i in range(n_styles)])
32
+ n = len(seg_paras[0])
33
+ result.append([" ".join(sp[i] for sp in seg_paras) for i in range(n)])
34
34
  seg_idx += len(segs)
35
35
  return result
36
36
 
@@ -48,14 +48,14 @@ def postprocess(
48
48
  Parameters
49
49
  ----------
50
50
  candidate : list[list[str]]
51
- Raw generation output, shape ``[n_generation_texts][n_styles]``.
51
+ Raw generation output, shape ``[n_generation_texts][n]``.
52
52
  context : PreprocessContext
53
53
  Context returned by :func:`~diversify_text._preprocess.preprocess`.
54
54
 
55
55
  Returns
56
56
  -------
57
57
  list[list[str]]
58
- Shape ``[n_texts][n_styles]`` — one paraphrase per original text
58
+ Shape ``[n_texts][n]`` — one paraphrase per original text
59
59
  per style.
60
60
  """
61
61
  if context.segments_per_text is not None:
@@ -0,0 +1,65 @@
1
+ """Shared internal utilities for diversify."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import contextlib
6
+ import itertools
7
+ import logging
8
+ import sys
9
+ import threading
10
+ import warnings
11
+
12
+
13
def default_device() -> str:
    """Pick a torch device string: ``"cuda"``, else ``"mps"``, else ``"cpu"``.

    CUDA is preferred when ``torch.cuda.is_available()`` reports it,
    then Apple's MPS backend, with CPU as the fallback.
    """
    # Imported lazily so that importing this module does not import torch.
    import torch

    if torch.cuda.is_available():
        return "cuda"
    return "mps" if torch.backends.mps.is_available() else "cpu"
22
+
23
+
24
@contextlib.contextmanager
def spinner(message: str = "Loading"):
    """Display a CLI spinner on stderr while a blocking operation runs.

    A daemon thread redraws ``"<frame> <message>"`` on the current line
    roughly every 0.08 seconds until the ``with`` body finishes. It then
    overwrites the line with a final status and a newline:
    ``"✓ <message>"`` when the body completed normally, or
    ``"✗ <message>"`` when it exited with an exception (which is
    propagated unchanged).

    Parameters
    ----------
    message : str
        Text shown next to the spinner frame and in the final status line.
    """
    frames = itertools.cycle(["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"])
    stop = threading.Event()
    # Set to True only once the body returns without raising. The
    # spinner thread reads it after ``stop`` is set to pick the mark.
    succeeded = False

    def _spin() -> None:
        while not stop.is_set():
            frame = next(frames)
            sys.stderr.write(f"\r{frame} {message}")
            sys.stderr.flush()
            # Doubles as the frame delay and a prompt wake-up on stop.
            stop.wait(0.08)
        final_mark = "✓" if succeeded else "✗"
        sys.stderr.write(f"\r{final_mark} {message}\n")
        sys.stderr.flush()

    thread = threading.Thread(target=_spin, daemon=True)
    thread.start()
    try:
        yield
        succeeded = True
    finally:
        # Always stop and join so the final status line is written
        # before control returns to the caller (or the error surfaces).
        stop.set()
        thread.join()
46
+
47
+
48
@contextlib.contextmanager
def suppress_hf_load_noise():
    """Temporarily quiet harmless output produced while loading HuggingFace models.

    While the context is active:

    - the ``"transformers"`` logger is raised to ``ERROR``, hiding
      lower-severity messages such as tied-weights notices and
      unexpected-key load reports that go through logging rather than
      the ``warnings`` module;
    - Python warnings whose message matches ``.*tie.*weight.*`` are
      ignored.

    On exit (normal or via exception) the warning filters and the
    logger's previous level are restored.
    """
    hf_logger = logging.getLogger("transformers")
    saved_level = hf_logger.level
    hf_logger.setLevel(logging.ERROR)
    try:
        # catch_warnings() scopes the extra filter to this context.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message=".*tie.*weight.*")
            yield
    finally:
        hf_logger.setLevel(saved_level)