diversify-text 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diversify_text/__init__.py +24 -0
- diversify_text/_input.py +267 -0
- diversify_text/_output.py +234 -0
- diversify_text/_postprocess.py +64 -0
- diversify_text/_preprocess.py +76 -0
- diversify_text/_utils.py +27 -0
- diversify_text/core.py +335 -0
- diversify_text/filter/__init__.py +5 -0
- diversify_text/filter/mis.py +272 -0
- diversify_text/method/__init__.py +13 -0
- diversify_text/method/base.py +35 -0
- diversify_text/method/echo.py +25 -0
- diversify_text/method/registry.py +109 -0
- diversify_text/method/tinystyler/__init__.py +6 -0
- diversify_text/method/tinystyler/method.py +164 -0
- diversify_text/method/tinystyler/model.py +113 -0
- diversify_text/method/tinystyler/styles.py +359 -0
- diversify_text/py.typed +0 -0
- diversify_text-0.1.1.dist-info/METADATA +272 -0
- diversify_text-0.1.1.dist-info/RECORD +22 -0
- diversify_text-0.1.1.dist-info/WHEEL +4 -0
- diversify_text-0.1.1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""diversify-text -- generate stylistic paraphrases of texts."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
from diversify_text.core import (
|
|
6
|
+
Diversifier,
|
|
7
|
+
diversify,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
"Diversifier",
|
|
12
|
+
"diversify",
|
|
13
|
+
]
|
|
14
|
+
|
|
15
|
+
# The package installs its own stream handler so that its INFO/WARNING
# records are printed even when the host application never set up logging.
_handler = logging.StreamHandler()
_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))

_logger = logging.getLogger("diversify_text")
_logger.setLevel(logging.INFO)
_logger.addHandler(_handler)

# Records stop at this logger instead of continuing to the root logger; if the
# application configured logging globally they would otherwise appear twice.
_logger.propagate = False
|
diversify_text/_input.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
1
|
+
"""Input resolution for diversify.
|
|
2
|
+
|
|
3
|
+
Converts the many input forms users can provide (single string, list,
|
|
4
|
+
generator, CSV/TSV/TXT file path) into a uniform ``Iterator[str]`` plus
|
|
5
|
+
an :class:`InputContext` that describes the source.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import csv
|
|
11
|
+
import logging
|
|
12
|
+
from collections.abc import Iterable, Iterator
|
|
13
|
+
from dataclasses import dataclass
|
|
14
|
+
from enum import Enum, auto
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from typing import Union
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# ------------------------------------------------------------------
|
|
20
|
+
# Type alias
|
|
21
|
+
# ------------------------------------------------------------------
|
|
22
|
+
|
|
23
|
+
TextInput = Union[str, "list[str]", "Iterable[str]"]
|
|
24
|
+
|
|
25
|
+
_log = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# ------------------------------------------------------------------
|
|
29
|
+
# Input kind & context
|
|
30
|
+
# ------------------------------------------------------------------
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class InputKind(Enum):
    """Tag recording the form in which the caller supplied the texts."""

    # Values handed over directly from Python code.
    SINGLE_STR = auto()
    LIST = auto()
    ITERABLE = auto()

    # Values read from a file, one member per supported extension.
    FILE_CSV = auto()
    FILE_TSV = auto()
    FILE_TXT = auto()
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass(frozen=True)
class InputContext:
    """Immutable description of where the resolved texts came from."""

    # How the input was supplied.
    kind: InputKind
    # Source file; filled in for the FILE_* kinds.
    input_path: Path | None = None
    # Column the texts are read from (CSV/TSV inputs).
    text_column: str | None = None
    # Expected number of texts, or None when it is not known up front.
    total: int | None = None
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# ------------------------------------------------------------------
|
|
55
|
+
# Public API
|
|
56
|
+
# ------------------------------------------------------------------
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def resolve_input(
    texts: TextInput,
    text_column: str = "text",
) -> tuple[Iterator[str], InputContext]:
    """Convert any supported input into a lazy ``Iterator[str]`` plus metadata.

    A ``str`` is treated as a file path only when it ends in ``.csv``,
    ``.tsv`` or ``.txt`` (case-insensitive) *and* names an existing regular
    file.  Every other string is a single text to diversify, including
    strings that merely look like a file name or that are too long to be a
    valid path.

    Parameters
    ----------
    texts : str | list[str] | Iterable[str]
        A single text, a list of texts, a generator / iterable of texts,
        or a path to a ``.csv``, ``.tsv``, or ``.txt`` file.
    text_column : str
        Column name to extract when *texts* points to a CSV/TSV file.

    Returns
    -------
    (Iterator[str], InputContext)
        The texts, and a description of their source.  ``total`` is the
        number of texts when it is known up front and ``None`` for a
        generic iterable.

    Raises
    ------
    ValueError
        If *texts* is a CSV/TSV file whose header lacks *text_column*.
    TypeError
        If *texts* is neither a ``str``, a ``list`` nor an iterable.
    """
    # --- str: could be a file path or a single text ---
    if isinstance(texts, str):
        path = Path(texts)
        suffix = path.suffix.lower()

        # CSV and TSV share one code path; they differ only in the field
        # delimiter and in the kind reported back to the caller.
        delimited = {
            ".csv": (",", InputKind.FILE_CSV),
            ".tsv": ("\t", InputKind.FILE_TSV),
        }
        if suffix in delimited and _is_existing_file(path):
            delimiter, kind = delimited[suffix]
            _validate_csv_header(path, text_column, delimiter=delimiter)
            total = _count_file_lines(path) - 1  # subtract header row
            return _iter_csv(path, text_column, delimiter=delimiter), InputContext(
                kind=kind,
                input_path=path,
                text_column=text_column,
                total=total,
            )

        if suffix == ".txt" and _is_existing_file(path):
            total = _count_nonempty_lines(path)
            _log.warning(
                "TXT file input: each line is treated as a separate text "
                "to diversify. Newlines are never part of a parsed text. "
                "(%d non-empty lines found in %s)",
                total,
                path,
            )
            return _iter_txt_lines(path), InputContext(
                kind=InputKind.FILE_TXT,
                input_path=path,
                total=total,
            )

        # Not a recognized file — treat as a single text.
        return iter([texts]), InputContext(kind=InputKind.SINGLE_STR, total=1)

    # --- list[str] ---
    if isinstance(texts, list):
        return iter(texts), InputContext(kind=InputKind.LIST, total=len(texts))

    # --- generic Iterable[str] (generators, file handles, …) ---
    if isinstance(texts, Iterable):
        return iter(texts), InputContext(kind=InputKind.ITERABLE, total=None)

    raise TypeError(
        f"Unsupported input type {type(texts).__name__}. "
        "Expected str, list[str], or Iterable[str]."
    )


def _is_existing_file(path: Path) -> bool:
    """Return whether *path* is an existing regular file, tolerating non-paths.

    ``Path.is_file()`` can raise for a string that cannot be a path at all:
    ``OSError`` with ``ENAMETOOLONG`` for an over-long name, or ``ValueError``
    for a name the OS cannot encode.  Such a string is ordinary text, not a
    file, so both cases answer ``False``.  Every other ``OSError`` (for
    example a permission error) is re-raised so real I/O problems stay
    visible.
    """
    import errno

    try:
        return path.is_file()
    except ValueError:
        return False
    except OSError as exc:
        if exc.errno == errno.ENAMETOOLONG:
            return False
        raise
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
# ------------------------------------------------------------------
|
|
135
|
+
# Validation & counting (private, cheap first-pass helpers)
|
|
136
|
+
# ------------------------------------------------------------------
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _validate_csv_header(path: Path, text_column: str, delimiter: str) -> None:
    """Ensure the header row of a CSV/TSV file contains *text_column*.

    Only the header row is read; the file is closed again before the
    function returns or raises.

    Parameters
    ----------
    path : Path
        Path to the CSV or TSV file.
    text_column : str
        Column name that must be present.
    delimiter : str
        Field separator — ``","`` for CSV, ``"\\t"`` for TSV.

    Raises
    ------
    ValueError
        If the header is missing or does not list *text_column*.  The
        message names the columns that were found.
    """
    with open(path, newline="", encoding="utf-8") as handle:
        # Accessing ``fieldnames`` makes DictReader consume the header row.
        header = csv.DictReader(handle, delimiter=delimiter).fieldnames

    columns = header or []
    if text_column in columns:
        return

    available = ", ".join(columns)
    raise ValueError(
        f"Column '{text_column}' not found in {path}. "
        f"Available: {available}"
    )
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _count_file_lines(path: Path) -> int:
    """Return how many lines *path* contains.

    The file is scanned line by line without any CSV parsing.  Callers
    handling CSV/TSV files subtract 1 for the header row.  A quoted CSV
    field that spans several physical lines is counted once per physical
    line, so the figure can exceed the real row count; it is only meant
    as a progress-bar estimate.

    Parameters
    ----------
    path : Path
        Path to the file.

    Returns
    -------
    int
        Total number of lines.
    """
    line_total = 0
    with open(path, encoding="utf-8") as handle:
        for _ in handle:
            line_total += 1
    return line_total
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def _count_nonempty_lines(path: Path) -> int:
    """Return how many lines of *path* contain non-whitespace characters.

    Parameters
    ----------
    path : Path
        Path to the file.

    Returns
    -------
    int
        Number of lines that are neither blank nor whitespace-only.
    """
    with open(path, encoding="utf-8") as handle:
        # Each kept line contributes True (== 1) to the sum.
        return sum(bool(raw.strip()) for raw in handle)
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
# ------------------------------------------------------------------
|
|
209
|
+
# Lazy file iterators (private)
|
|
210
|
+
# ------------------------------------------------------------------
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def _iter_csv(path: Path, text_column: str, delimiter: str) -> Iterator[str]:
    """Stream the *text_column* value of every row in a CSV/TSV file.

    Rows are produced one at a time.  The file is opened when iteration
    starts and closed when the generator finishes, is closed explicitly,
    or is garbage collected.

    Parameters
    ----------
    path : Path
        Path to the CSV or TSV file.
    text_column : str
        Name of the column that contains the texts to diversify.
    delimiter : str
        Field separator — ``","`` for CSV, ``"\\t"`` for TSV.

    Yields
    ------
    str
        The text value from each row; a missing or empty cell yields ``""``.
    """
    with open(path, newline="", encoding="utf-8") as handle:
        for record in csv.DictReader(handle, delimiter=delimiter):
            cell = record.get(text_column)
            yield cell if cell else ""
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def _iter_txt_lines(path: Path) -> Iterator[str]:
    """Stream the non-blank lines of a ``.txt`` file, whitespace-trimmed.

    Lines that are empty or contain only whitespace are dropped.  Lines
    are read one at a time.  The file is opened when iteration starts and
    closed when the generator finishes, is closed explicitly, or is
    garbage collected.

    Parameters
    ----------
    path : Path
        Path to the ``.txt`` file.

    Yields
    ------
    str
        Each non-empty line, with leading/trailing whitespace stripped.
    """
    with open(path, encoding="utf-8") as handle:
        yield from (trimmed for trimmed in map(str.strip, handle) if trimmed)
|
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
"""Output path resolution and incremental writing for diversify.
|
|
2
|
+
|
|
3
|
+
Decides *where* results go (in-memory vs. disk) and writes them in the
|
|
4
|
+
appropriate format (Python list or JSONL).
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
import logging
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import IO, Any, Union
|
|
13
|
+
|
|
14
|
+
from diversify_text._input import InputContext, InputKind
|
|
15
|
+
|
|
16
|
+
_log = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# ------------------------------------------------------------------
|
|
20
|
+
# Type alias
|
|
21
|
+
# ------------------------------------------------------------------
|
|
22
|
+
|
|
23
|
+
DiversifyOutput = Union[list[dict], Path]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# ------------------------------------------------------------------
|
|
27
|
+
# Output path resolution
|
|
28
|
+
# ------------------------------------------------------------------
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def resolve_output_path(
    input_context: InputContext,
    output_dir: str | Path | None = None,
    output_name: str | None = None,
) -> Path | None:
    """Work out the JSONL file results go to, or ``None`` to keep them in memory.

    The caller chooses the directory and the file stem; the extension is
    always ``.jsonl``.

    Directory, when *output_dir* is ``None``:

    * ``SINGLE_STR`` / ``LIST`` → no file at all; ``None`` is returned and
      results stay in memory.
    * ``ITERABLE`` → the current working directory.
    * ``FILE_CSV`` / ``FILE_TSV`` / ``FILE_TXT`` → the directory holding
      the input file.

    Passing *output_dir* forces disk output for every input kind.

    Name, when *output_name* is ``None``:

    * ``FILE_CSV`` / ``FILE_TSV`` → ``<input_stem>_diversified``
    * ``FILE_TXT`` → ``<input_stem>``
    * anything else → ``diversified_output``

    Parameters
    ----------
    input_context : InputContext
        Metadata produced by :func:`resolve_input`.
    output_dir : str, Path, or None
        Directory to write the output file into.
    output_name : str or None
        Base filename.  ``.jsonl`` is appended unless the name already
        ends with it; any other extension in the name is kept as-is.

    Returns
    -------
    Path or None
        ``None`` means in-memory mode; otherwise the path to write to.
    """
    kind = input_context.kind
    source = input_context.input_path

    # --- directory ---
    if output_dir is None:
        if kind in (InputKind.SINGLE_STR, InputKind.LIST):
            # In-memory input and no directory requested: stay in memory.
            return None
        if kind == InputKind.ITERABLE:
            target_dir = Path.cwd()
        else:
            # File inputs write next to the file they were read from.
            assert source is not None
            target_dir = source.parent
    else:
        target_dir = Path(output_dir)

    # --- file stem ---
    if output_name is None:
        if kind in (InputKind.FILE_CSV, InputKind.FILE_TSV):
            assert source is not None
            stem = f"{source.stem}_diversified"
        elif kind == InputKind.FILE_TXT:
            assert source is not None
            stem = source.stem
        else:
            # ITERABLE, or LIST/SINGLE_STR with an explicit output_dir.
            stem = "diversified_output"
    else:
        stem = output_name

    # --- final path, always ending in .jsonl ---
    filename = stem if stem.endswith(".jsonl") else f"{stem}.jsonl"
    destination = target_dir / filename

    _log.info("Output will be written to %s", destination)
    return destination
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
# ------------------------------------------------------------------
|
|
113
|
+
# Output writer
|
|
114
|
+
# ------------------------------------------------------------------
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class OutputWriter:
    """Incrementally collect diversify results in memory or in a JSONL file.

    Modes
    -----
    * **In-memory** (``output_path is None``): accumulates
      ``list[dict]`` with keys ``"original"`` and ``"paraphrases"``.
    * **JSONL** (``output_path is not None``): writes one JSON object
      per line to a ``.jsonl`` file.

    Lifecycle: :meth:`open`, any number of :meth:`write_batch` calls, then
    :meth:`finish`.  The writer can also be used as a context manager;
    entering calls :meth:`open`, and leaving closes the file even when the
    body raises::

        with OutputWriter(ctx, n_styles, path) as writer:
            writer.write_batch(originals, paraphrases)
            result = writer.finish()
    """

    def __init__(
        self,
        input_context: InputContext,
        n_styles: int,
        output_path: Path | None,
    ) -> None:
        """Initialize the writer.

        Parameters
        ----------
        input_context : InputContext
            Metadata about the input source (kind, path, etc.).
        n_styles : int
            Number of paraphrase styles requested per text.
        output_path : Path or None
            Where to write results on disk.  ``None`` means results
            are kept in memory and returned as ``list[dict]``.
        """
        self._input_context = input_context
        self._n_styles = n_styles
        self._output_path = output_path
        # Open file handle — set by open() when writing to disk.
        self._handle: IO[str] | None = None
        # In-memory accumulator — used only when output_path is None.
        self._accumulated: list[dict[str, Any]] = []

    # --- context-manager support ---

    def __enter__(self) -> "OutputWriter":
        """Open the writer and return it."""
        self.open()
        return self

    def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
        """Close the file handle if it is still open; never swallow errors."""
        if self._handle is not None:
            self._handle.close()
            self._handle = None

    # --- lifecycle: open / write / close ---

    def open(self) -> None:
        """Open the file handle when writing to disk.

        Must be called before :meth:`write_batch`.

        * ``output_path is None`` — does nothing (in-memory mode).
        * Otherwise — creates the parent directory if needed and opens a
          single JSONL file for writing, replacing any existing file.
        """
        if self._output_path is None:
            # In-memory mode: nothing to open.
            return

        # Disk mode: open a single JSONL file.
        self._output_path.parent.mkdir(parents=True, exist_ok=True)
        self._handle = open(self._output_path, "w", encoding="utf-8")

    def write_batch(
        self,
        originals: list[str],
        paraphrases_by_text: list[list[str]],
    ) -> None:
        """Append one batch of results.

        Parameters
        ----------
        originals : list[str]
            The original texts in this batch.
        paraphrases_by_text : list[list[str]]
            One inner list per original text, each containing *n_styles*
            paraphrased variants.  For example, with 2 styles and 2
            texts: ``[["a_style1", "a_style2"], ["b_style1", "b_style2"]]``.
            An inner list of a different length is still written, after
            a warning is logged.

        Raises
        ------
        ValueError
            If ``originals`` and ``paraphrases_by_text`` have different
            lengths.
        RuntimeError
            If the writer is in disk mode and :meth:`open` has not been
            called (or :meth:`finish` already closed the file).
        """
        if len(originals) != len(paraphrases_by_text):
            raise ValueError(
                f"originals has {len(originals)} items but "
                f"paraphrases_by_text has {len(paraphrases_by_text)}."
            )

        for i, (orig, paras) in enumerate(zip(originals, paraphrases_by_text)):
            if len(paras) != self._n_styles:
                _log.warning(
                    "Expected %d paraphrases for text %d, got %d.",
                    self._n_styles, i, len(paras),
                )
            record = {"original": orig, "paraphrases": paras}
            if self._output_path is None:
                self._accumulated.append(record)
                continue

            # An explicit exception rather than ``assert``: assertions are
            # stripped under ``python -O``, which would turn this misuse
            # into an opaque AttributeError on ``None``.
            if self._handle is None:
                raise RuntimeError(
                    "OutputWriter is not open; call open() before write_batch()."
                )
            self._handle.write(json.dumps(record, ensure_ascii=False) + "\n")

    def finish(self) -> DiversifyOutput:
        """Close the file handle and return the final result.

        Returns
        -------
        list[dict]
            When ``output_path`` was ``None`` (in-memory mode).  Each dict
            has keys ``"original"`` and ``"paraphrases"``.
        Path
            When results were written to disk — the ``.jsonl`` path.
        """
        if self._handle is not None:
            self._handle.close()
            self._handle = None

        if self._output_path is None:
            # In-memory mode: return the accumulated list of dicts.
            return self._accumulated

        # Disk mode: return the output path.
        return self._output_path
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
"""Text postprocessing utilities for diversify."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from diversify_text._preprocess import PreprocessContext
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def reassemble_segments(
    segments_per_text: list[list[str]],
    paraphrases_by_segment: list[list[str]],
) -> list[list[str]]:
    """Stitch per-segment paraphrases back together, one group per text.

    The flat *paraphrases_by_segment* list is consumed in order: each text
    takes as many entries as it has segments.  For every style, the
    paraphrases of a text's segments are joined with a single space.

    Parameters
    ----------
    segments_per_text : list[list[str]]
        The sentence segments for each original text (from
        :func:`~diversify_text._preprocess.split_sentences`).
    paraphrases_by_segment : list[list[str]]
        Flat list of paraphrases for every segment, shape
        ``[total_segments][n_styles]``.

    Returns
    -------
    list[list[str]]
        Shape ``[n_texts][n_styles]`` — reassembled paraphrases.
    """
    reassembled = []
    cursor = 0
    for segments in segments_per_text:
        stop = cursor + len(segments)
        chunk = paraphrases_by_segment[cursor:stop]
        cursor = stop

        # The first segment of the text determines how many styles to emit.
        style_count = len(chunk[0])
        per_style = []
        for style in range(style_count):
            per_style.append(" ".join(variants[style] for variants in chunk))
        reassembled.append(per_style)
    return reassembled
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def postprocess(
    candidate: list[list[str]],
    context: PreprocessContext,
) -> list[list[str]]:
    """Reverse the preprocessing steps recorded in *context*.

    When *context* carries a per-text segment mapping, the per-segment
    paraphrases are rejoined into one paraphrase per original text.
    Without a mapping, *candidate* is handed back untouched.

    Parameters
    ----------
    candidate : list[list[str]]
        Raw generation output, shape ``[n_generation_texts][n_styles]``.
    context : PreprocessContext
        Context returned by :func:`~diversify_text._preprocess.preprocess`.

    Returns
    -------
    list[list[str]]
        Shape ``[n_texts][n_styles]`` — one paraphrase per original text
        per style.
    """
    segments = context.segments_per_text
    if segments is None:
        # No sentence splitting happened, so there is nothing to undo.
        return candidate
    return reassemble_segments(segments, candidate)
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
"""Text preprocessing utilities for diversify."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
|
|
7
|
+
import pysbd # https://github.com/nipunsadvilkar/pySBD, published at EMNLP 2020, rule-based
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
_SEGMENTER = pysbd.Segmenter(language="en", clean=False)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class PreprocessContext:
    """Record of what :func:`preprocess` did to a batch.

    :func:`~diversify_text._postprocess.postprocess` reads it to reverse
    those steps.  Further preprocessing steps can store their state in
    additional fields without the caller in ``core.py`` having to change.
    """

    # Sentence segments of each original text; None when no splitting was done.
    segments_per_text: list[list[str]] | None = None
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def split_sentences(text: str) -> list[str]:
    """Break *text* into sentences with the module-level pysbd segmenter.

    Each sentence is whitespace-trimmed and empty ones are dropped.  The
    result always has at least one element: when nothing survives (for
    example for empty or whitespace-only input) it is a single-element
    list holding the trimmed, possibly empty, input.
    """
    trimmed_text = text.strip()

    sentences = []
    for piece in _SEGMENTER.segment(trimmed_text):
        sentence = piece.strip()
        if sentence:
            sentences.append(sentence)

    return sentences if sentences else [trimmed_text]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def preprocess(
    texts: list[str],
    *,
    split_on_punctuation: bool = False,
) -> tuple[list[str], PreprocessContext]:
    """Turn a batch of texts into the inputs for the generation method.

    Alongside the texts to generate from, a :class:`PreprocessContext` is
    returned so that :func:`~diversify_text._postprocess.postprocess` can
    reverse whatever was done here.

    Parameters
    ----------
    texts : list[str]
        Original input texts.
    split_on_punctuation : bool
        If ``True``, every text is split into sentence-level segments and
        the segments of all texts are flattened into one list.  The
        per-text segments are kept in the context so that
        :func:`~diversify_text._postprocess.postprocess` can reassemble
        them.  If ``False``, *texts* is passed through unchanged.

    Returns
    -------
    generation_texts : list[str]
        Texts to pass to the generation method.
    context : PreprocessContext
        Context needed by :func:`~diversify_text._postprocess.postprocess`.
    """
    if not split_on_punctuation:
        # Nothing to transform: same list object, empty context.
        return texts, PreprocessContext()

    per_text = [split_sentences(text) for text in texts]

    flattened = []
    for sentence_group in per_text:
        flattened.extend(sentence_group)

    return flattened, PreprocessContext(segments_per_text=per_text)
|
diversify_text/_utils.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Shared internal utilities for diversify."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import contextlib
|
|
6
|
+
import logging
|
|
7
|
+
import warnings
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@contextlib.contextmanager
def suppress_hf_load_noise():
    """Quieten harmless chatter while the wrapped block runs.

    For the duration of the ``with`` body:

    - the ``transformers`` logger is raised to ``ERROR``, hiding its
      lower-severity records;
    - Python warnings whose message matches ``.*tie.*weight.*`` are ignored.

    On exit, whether normal or through an exception, the warning filters
    and the logger's previous level are restored.
    """
    hf_logger = logging.getLogger("transformers")
    saved_level = hf_logger.level
    hf_logger.setLevel(logging.ERROR)

    with contextlib.ExitStack() as cleanup:
        # Registered first so it runs last: the warning filters are restored
        # before the logger level is put back.
        cleanup.callback(hf_logger.setLevel, saved_level)
        cleanup.enter_context(warnings.catch_warnings())
        warnings.filterwarnings("ignore", message=".*tie.*weight.*")
        yield
|