diversify-text 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,24 @@
1
+ """diversify-text -- generate stylistic paraphrases of texts."""
2
+
3
+ import logging
4
+
5
+ from diversify_text.core import (
6
+ Diversifier,
7
+ diversify,
8
+ )
9
+
10
# Public API of the package: both names are re-exported from
# diversify_text.core.
__all__ = ["Diversifier", "diversify"]
14
+
15
# Configure a clean handler for the diversify logger so INFO/WARNING messages
# are visible without requiring the user to set up logging themselves.
#
# NOTE: attaching a StreamHandler from a library is usually discouraged (see
# the logging HOWTO); it is kept deliberately so messages show up out of the
# box. The setup below is idempotent:
#   * the handler is tagged, so re-importing/reloading the package (e.g.
#     importlib.reload in a notebook) does not stack a second handler and
#     print every message twice;
#   * a level the user already set on "diversify_text" is respected; INFO is
#     only applied when the logger is still NOTSET.
_logger = logging.getLogger("diversify_text")
if _logger.level == logging.NOTSET:
    _logger.setLevel(logging.INFO)
if not any(
    getattr(existing, "_diversify_text_handler", False)
    for existing in _logger.handlers
):
    _handler = logging.StreamHandler()
    _handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
    # Marker used by the check above to detect our own handler.
    _handler._diversify_text_handler = True
    _logger.addHandler(_handler)
# Prevent messages from bubbling up to the root logger (avoids duplicate output
# if the user has already configured logging globally).
_logger.propagate = False
@@ -0,0 +1,267 @@
1
+ """Input resolution for diversify.
2
+
3
+ Converts the many input forms users can provide (single string, list,
4
+ generator, CSV/TSV/TXT file path) into a uniform ``Iterator[str]`` plus
5
+ an :class:`InputContext` that describes the source.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import csv
11
+ import logging
12
+ from collections.abc import Iterable, Iterator
13
+ from dataclasses import dataclass
14
+ from enum import Enum, auto
15
+ from pathlib import Path
16
+ from typing import Union
17
+
18
+
19
+ # ------------------------------------------------------------------
20
+ # Type alias
21
+ # ------------------------------------------------------------------
22
+
23
# Everything resolve_input() accepts. The generic members are kept as
# string forward references.
TextInput = Union[
    str,  # a single text, or a path to a .csv/.tsv/.txt file
    "list[str]",  # texts whose count is known up front
    "Iterable[str]",  # any other iterable of texts (e.g. a generator)
]

# Module logger; resolve_input() emits its TXT-file notice through it.
_log = logging.getLogger(__name__)
26
+
27
+
28
+ # ------------------------------------------------------------------
29
+ # Input kind & context
30
+ # ------------------------------------------------------------------
31
+
32
+
33
class InputKind(Enum):
    """Discriminator for how the user provided input.

    The values are the consecutive integers that ``auto()`` assigns to a
    plain ``Enum`` (starting at 1), written out explicitly.
    """

    # In-memory inputs.
    SINGLE_STR = 1  # one text passed as a plain ``str``
    LIST = 2  # a ``list`` of texts (length known up front)
    ITERABLE = 3  # any other iterable, e.g. a generator (length unknown)

    # File inputs, selected by the path's suffix.
    FILE_CSV = 4  # ``.csv`` file, comma-delimited
    FILE_TSV = 5  # ``.tsv`` file, tab-delimited
    FILE_TXT = 6  # ``.txt`` file, one text per non-empty line
42
+
43
+
44
@dataclass(frozen=True)
class InputContext:
    """Immutable description of where the texts being diversified come from.

    Attributes
    ----------
    kind : InputKind
        How the input was provided.
    input_path : Path or None
        Source file for file inputs; ``None`` for in-memory inputs.
    text_column : str or None
        Column read from a CSV/TSV file; ``None`` for every other kind.
    total : int or None
        Number of texts when it can be determined up front (an estimate
        for CSV/TSV files); ``None`` for generic iterables.
    """

    # Required discriminator; every other field is optional metadata.
    kind: InputKind
    # File-input metadata.
    input_path: Path | None = None
    text_column: str | None = None
    # Expected number of texts, if known.
    total: int | None = None
52
+
53
+
54
+ # ------------------------------------------------------------------
55
+ # Public API
56
+ # ------------------------------------------------------------------
57
+
58
+
59
def resolve_input(
    texts: TextInput,
    text_column: str = "text",
) -> tuple[Iterator[str], InputContext]:
    """Convert any supported input into a lazy ``Iterator[str]`` plus metadata.

    A ``str`` is treated as a file path only when its suffix is ``.csv``,
    ``.tsv`` or ``.txt`` (case-insensitive) *and* it names an existing
    regular file. Any other string is a single text.

    Parameters
    ----------
    texts : str | list[str] | Iterable[str]
        A single text, a list of texts, a generator / iterable of texts,
        or a path to a ``.csv``, ``.tsv``, or ``.txt`` file.
    text_column : str
        Column name to extract when *texts* points to a CSV/TSV file.

    Returns
    -------
    (Iterator[str], InputContext)

    Raises
    ------
    TypeError
        If *texts* is not a ``str``/``list``/iterable, or is ``bytes`` /
        ``bytearray``. Those are iterable, but they iterate as integers,
        not texts.
    ValueError
        If a CSV/TSV file does not contain *text_column* (raised by the
        header check).
    """
    # --- str: could be a file path or a single text ---
    if isinstance(texts, str):
        return _resolve_str_input(texts, text_column)

    # --- list[str] ---
    if isinstance(texts, list):
        return iter(texts), InputContext(kind=InputKind.LIST, total=len(texts))

    # --- generic Iterable[str] (generators, file handles, …) ---
    # bytes/bytearray are Iterable too, but they yield ints, so they fall
    # through to the TypeError below.
    if isinstance(texts, Iterable) and not isinstance(texts, (bytes, bytearray)):
        return iter(texts), InputContext(kind=InputKind.ITERABLE, total=None)

    raise TypeError(
        f"Unsupported input type {type(texts).__name__}. "
        "Expected str, list[str], or Iterable[str]."
    )


def _resolve_str_input(
    texts: str, text_column: str
) -> tuple[Iterator[str], InputContext]:
    """Resolve a ``str``: an existing CSV/TSV/TXT file path, else a single text."""
    path = Path(texts)
    suffix = path.suffix.lower()

    if suffix in (".csv", ".tsv", ".txt"):
        if _is_existing_file(path):
            return _resolve_file_input(path, suffix, text_column)
        # Most likely a mistyped path. Say so instead of silently
        # paraphrasing the path string itself.
        _log.warning(
            "Input string has a %s suffix but is not an existing file; "
            "treating it as a single text to diversify.",
            suffix,
        )

    # Not a recognized file — treat as a single text.
    return iter([texts]), InputContext(kind=InputKind.SINGLE_STR, total=1)


def _resolve_file_input(
    path: Path, suffix: str, text_column: str
) -> tuple[Iterator[str], InputContext]:
    """Open an existing ``.csv``/``.tsv``/``.txt`` file as a lazy text iterator."""
    if suffix == ".txt":
        total = _count_nonempty_lines(path)
        _log.warning(
            "TXT file input: each line is treated as a separate text "
            "to diversify. Newlines are never part of a parsed text. "
            "(%d non-empty lines found in %s)",
            total,
            path,
        )
        return _iter_txt_lines(path), InputContext(
            kind=InputKind.FILE_TXT,
            input_path=path,
            total=total,
        )

    if suffix == ".csv":
        delimiter, kind = ",", InputKind.FILE_CSV
    else:  # ".tsv"
        delimiter, kind = "\t", InputKind.FILE_TSV

    _validate_csv_header(path, text_column, delimiter=delimiter)
    # Subtract the header row. Clamp at 0 so the estimate is never negative.
    total = max(_count_file_lines(path) - 1, 0)
    return _iter_csv(path, text_column, delimiter=delimiter), InputContext(
        kind=kind,
        input_path=path,
        text_column=text_column,
        total=total,
    )


def _is_existing_file(path: Path) -> bool:
    """Return ``True`` if *path* names an existing regular file, never raising.

    ``Path.is_file()`` only swallows "does not exist"-type errors and
    propagates others. Arbitrary user text that merely ends in ``.txt``
    can, for example, exceed the OS file-name length limit. Such strings
    must be treated as plain texts, not crash the call.
    """
    try:
        return path.is_file()
    except (OSError, ValueError):
        return False
132
+
133
+
134
+ # ------------------------------------------------------------------
135
+ # Validation & counting (private, cheap first-pass helpers)
136
+ # ------------------------------------------------------------------
137
+
138
+
139
# Encoding used for every input file read below. "utf-8-sig" decodes plain
# UTF-8 unchanged and additionally strips a leading byte-order mark, as
# written by some spreadsheet "CSV UTF-8" exports. With plain "utf-8" the
# BOM stays glued to the first header name ("\ufefftext"), so the text
# column is not found. In a .txt file it would leak into the first text.
# Header validation and row iteration must use the same decoding, which is
# why the constant is shared.
_INPUT_ENCODING = "utf-8-sig"


def _validate_csv_header(path: Path, text_column: str, delimiter: str) -> None:
    """Check that *text_column* exists in the CSV/TSV header.

    Opens the file, reads only the header row, then closes it.

    Parameters
    ----------
    path : Path
        Path to the CSV or TSV file.
    text_column : str
        Expected column name.
    delimiter : str
        Field separator — ``","`` for CSV, ``"\\t"`` for TSV.

    Raises
    ------
    ValueError
        If the file is empty (no header row), or *text_column* is not found
        among the file's header fields.
    """
    with open(path, newline="", encoding=_INPUT_ENCODING) as f:
        fieldnames = csv.DictReader(f, delimiter=delimiter).fieldnames

    if fieldnames is None:
        # DictReader found no first row at all.
        raise ValueError(
            f"Column '{text_column}' not found in {path}. "
            "The file is empty (no header row)."
        )
    if text_column not in fieldnames:
        available = ", ".join(fieldnames)
        raise ValueError(
            f"Column '{text_column}' not found in {path}. "
            f"Available: {available}"
        )


def _count_file_lines(path: Path) -> int:
    """Count the total number of lines in a file.

    This is a cheap pass that only reads raw lines — no CSV parsing.
    For CSV/TSV files, subtract 1 for the header to get the data-row
    count. The count may slightly overestimate if the file contains
    multi-line quoted CSV fields or blank lines, but that is acceptable
    for a progress bar.

    Parameters
    ----------
    path : Path
        Path to the file.

    Returns
    -------
    int
        Total number of lines.
    """
    with open(path, encoding=_INPUT_ENCODING) as f:
        return sum(1 for _ in f)


def _count_nonempty_lines(path: Path) -> int:
    """Count non-empty lines in a file (skips blank / whitespace-only).

    Uses the same decoding as :func:`_iter_txt_lines`, so the count matches
    the number of texts it yields.

    Parameters
    ----------
    path : Path
        Path to the file.

    Returns
    -------
    int
        Number of non-empty lines.
    """
    with open(path, encoding=_INPUT_ENCODING) as f:
        return sum(1 for line in f if line.strip())


# ------------------------------------------------------------------
# Lazy file iterators (private)
# ------------------------------------------------------------------


def _iter_csv(path: Path, text_column: str, delimiter: str) -> Iterator[str]:
    """Lazily yield *text_column* values from a CSV/TSV file.

    Only one row is held in memory at a time. The file is opened on the
    first ``next()`` and closed when the generator is exhausted, closed, or
    garbage collected.

    Parameters
    ----------
    path : Path
        Path to the CSV or TSV file.
    text_column : str
        Name of the column that contains the texts to diversify.
    delimiter : str
        Field separator — ``","`` for CSV, ``"\\t"`` for TSV.

    Yields
    ------
    str
        The text value from each row; ``""`` when the row has no value
        for *text_column* (e.g. a short row).
    """
    with open(path, newline="", encoding=_INPUT_ENCODING) as f:
        for row in csv.DictReader(f, delimiter=delimiter):
            yield row.get(text_column) or ""


def _iter_txt_lines(path: Path) -> Iterator[str]:
    """Lazily yield non-empty, stripped lines from a ``.txt`` file.

    Blank lines and whitespace-only lines are skipped. Only one line
    is held in memory at a time. The file is opened on the first
    ``next()`` and closed when the generator is exhausted, closed, or
    garbage collected.

    Parameters
    ----------
    path : Path
        Path to the ``.txt`` file.

    Yields
    ------
    str
        Each non-empty line, with leading/trailing whitespace stripped.
    """
    with open(path, encoding=_INPUT_ENCODING) as f:
        for line in f:
            stripped = line.strip()
            if stripped:
                yield stripped
@@ -0,0 +1,234 @@
1
+ """Output path resolution and incremental writing for diversify.
2
+
3
+ Decides *where* results go (in-memory vs. disk) and writes them in the
4
+ appropriate format (Python list or JSONL).
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ import logging
11
+ from pathlib import Path
12
+ from typing import IO, Any, Union
13
+
14
+ from diversify_text._input import InputContext, InputKind
15
+
16
# Module logger (reports the resolved output path, batch warnings).
_log = logging.getLogger(__name__)


# ------------------------------------------------------------------
# Type alias
# ------------------------------------------------------------------

# What OutputWriter.finish() returns, depending on the writer's mode.
DiversifyOutput = Union[
    list[dict],  # in-memory mode: the accumulated records
    Path,  # disk mode: path of the written .jsonl file
]
24
+
25
+
26
+ # ------------------------------------------------------------------
27
+ # Output path resolution
28
+ # ------------------------------------------------------------------
29
+
30
+
31
def resolve_output_path(
    input_context: InputContext,
    output_dir: str | Path | None = None,
    output_name: str | None = None,
) -> Path | None:
    """Determine where output should be written, or ``None`` for in-memory.

    The user controls *where* (directory) and *what name* (stem) to use,
    but the **extension is always ``.jsonl``**.

    Directory defaults (when *output_dir* is ``None``):

    * ``SINGLE_STR`` / ``LIST`` → ``None`` (keep in memory, return as
      ``list[dict]``). These are small, known-size inputs from Python
      code, so the caller typically wants results as Python objects.
    * ``ITERABLE`` → current working directory.
    * ``FILE_CSV`` / ``FILE_TSV`` / ``FILE_TXT`` → same directory as
      the input file.

    If *output_dir* is provided for ``SINGLE_STR`` / ``LIST``, results
    are written to disk instead of being returned in memory.

    Name defaults (when *output_name* is ``None``):

    * ``FILE_CSV`` / ``FILE_TSV`` → ``<input_stem>_diversified``
    * ``FILE_TXT`` → ``<input_stem>``
    * Everything else → ``diversified_output``

    Parameters
    ----------
    input_context : InputContext
        Metadata produced by :func:`resolve_input`.
    output_dir : str, Path, or None
        Directory to write output files into.
    output_name : str or None
        Base filename (without extension). ``.jsonl`` is appended
        automatically unless the name already ends with it. That check is
        case-insensitive, so ``"out.JSONL"`` is kept as is. Any other
        extension is **not** stripped; ``.jsonl`` is appended after it.

    Returns
    -------
    Path or None
        ``None`` means in-memory mode; otherwise the path to write to.

    Raises
    ------
    ValueError
        If a file-based *input_context* (CSV/TSV/TXT) has no
        ``input_path`` but one is needed to derive the directory or name.
    """
    kind = input_context.kind

    # --- determine directory ---
    if output_dir is not None:
        directory = Path(output_dir)
    elif kind in (InputKind.SINGLE_STR, InputKind.LIST):
        # No output_dir and in-memory input → stay in-memory.
        return None
    elif kind == InputKind.ITERABLE:
        # Iterable with no output_dir → default to current working directory.
        directory = Path.cwd()
    else:
        # FILE_CSV, FILE_TSV, FILE_TXT → same directory as input file.
        directory = _require_input_path(input_context).parent

    # --- determine base name ---
    if output_name is not None:
        name = output_name
    elif kind in (InputKind.FILE_CSV, InputKind.FILE_TSV):
        name = f"{_require_input_path(input_context).stem}_diversified"
    elif kind == InputKind.FILE_TXT:
        name = _require_input_path(input_context).stem
    else:
        # ITERABLE, or LIST/SINGLE_STR with output_dir.
        name = "diversified_output"

    # --- build final path with the correct extension ---
    # Case-insensitive, so "results.JSONL" does not become "results.JSONL.jsonl".
    if not name.lower().endswith(".jsonl"):
        name = f"{name}.jsonl"
    result = directory / name

    _log.info("Output will be written to %s", result)
    return result


def _require_input_path(input_context: InputContext) -> Path:
    """Return ``input_context.input_path``, raising ``ValueError`` if it is missing.

    An explicit check rather than ``assert``, so the guard is not stripped
    under ``python -O``.
    """
    if input_context.input_path is None:
        raise ValueError(
            f"{input_context.kind.name} input requires an input_path, "
            "but the InputContext has none."
        )
    return input_context.input_path
110
+
111
+
112
+ # ------------------------------------------------------------------
113
+ # Output writer
114
+ # ------------------------------------------------------------------
115
+
116
+
117
class OutputWriter:
    """Incrementally writes diversify results to the right format.

    Modes
    -----
    * **In-memory** (``output_path is None``): accumulates
      ``list[dict]`` with keys ``"original"`` and ``"paraphrases"``.
    * **JSONL** (``output_path is not None``): writes one JSON object
      per line to a ``.jsonl`` file.

    The lifecycle is ``open()`` → ``write_batch()`` (any number of times)
    → ``finish()``. The writer is also a context manager. Entering calls
    :meth:`open`, and leaving closes the file handle even if an exception
    interrupted writing. Call :meth:`finish` inside the ``with`` block to
    obtain the result.
    """

    def __init__(
        self,
        input_context: InputContext,
        n_styles: int,
        output_path: Path | None,
    ) -> None:
        """Initialize the writer.

        Parameters
        ----------
        input_context : InputContext
            Metadata about the input source (kind, path, etc.).
        n_styles : int
            Number of paraphrase styles requested per text.
        output_path : Path or None
            Where to write results on disk. ``None`` means results
            are kept in memory and returned as ``list[dict]``.
        """
        self._input_context = input_context
        self._n_styles = n_styles
        self._output_path = output_path
        # Open file handle — set by open() when writing to disk.
        self._handle: IO[str] | None = None
        # In-memory accumulator — used only when output_path is None.
        self._accumulated: list[dict[str, Any]] = []
        # Records processed so far across all batches. It lets warnings
        # report a text's position in the whole input, not within one batch.
        self._n_written = 0

    # --- context-manager support ---

    def __enter__(self) -> OutputWriter:
        self.open()
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Never suppresses the exception; only guarantees the file is closed.
        self._close()

    # --- lifecycle: open / write / close ---

    def open(self) -> None:
        """Open the file handle when writing to disk.

        Must be called before :meth:`write_batch`.

        * ``output_path is None`` — does nothing (in-memory mode).
        * Otherwise — creates the parent directory and opens a single JSONL
          file for writing. Calling it again while the file is open is a
          no-op.
        """
        if self._output_path is None:
            # In-memory mode: nothing to open.
            return
        if self._handle is not None:
            # Already open. Opening again would leak the current handle and
            # truncate everything written so far.
            return

        # Disk mode: open a single JSONL file.
        self._output_path.parent.mkdir(parents=True, exist_ok=True)
        self._handle = open(self._output_path, "w", encoding="utf-8")

    def write_batch(
        self,
        originals: list[str],
        paraphrases_by_text: list[list[str]],
    ) -> None:
        """Append one batch of results.

        Parameters
        ----------
        originals : list[str]
            The original texts in this batch.
        paraphrases_by_text : list[list[str]]
            One inner list per original text, each containing *n_styles*
            paraphrased variants. For example, with 2 styles and 2
            texts: ``[["a_style1", "a_style2"], ["b_style1", "b_style2"]]``.
            A text with a different number of variants is still written,
            and a warning reports its 0-based position among all texts
            written by this writer.

        Raises
        ------
        ValueError
            If ``originals`` and ``paraphrases_by_text`` have different
            lengths.
        RuntimeError
            If writing to disk and :meth:`open` has not been called.
        """
        if len(originals) != len(paraphrases_by_text):
            raise ValueError(
                f"originals has {len(originals)} items but "
                f"paraphrases_by_text has {len(paraphrases_by_text)}."
            )
        if self._output_path is not None and self._handle is None:
            # Explicit error instead of ``assert`` so it survives ``python -O``.
            raise RuntimeError(
                "OutputWriter.open() must be called before write_batch() "
                "when writing to disk."
            )

        lines: list[str] = []
        for orig, paras in zip(originals, paraphrases_by_text):
            if len(paras) != self._n_styles:
                _log.warning(
                    "Expected %d paraphrases for text %d, got %d.",
                    self._n_styles, self._n_written, len(paras),
                )
            record = {"original": orig, "paraphrases": paras}
            if self._output_path is None:
                self._accumulated.append(record)
            else:
                lines.append(json.dumps(record, ensure_ascii=False) + "\n")
            self._n_written += 1

        if lines:
            # One buffered write per batch instead of one call per record.
            self._handle.writelines(lines)

    def finish(self) -> DiversifyOutput:
        """Close the file handle and return the final result.

        Returns
        -------
        list[dict]
            When ``output_path`` was ``None`` (in-memory mode). Each dict
            has keys ``"original"`` and ``"paraphrases"``.
        Path
            When results were written to disk — the ``.jsonl`` path.
        """
        self._close()

        if self._output_path is None:
            # In-memory mode: return the accumulated list of dicts.
            return self._accumulated

        # Disk mode: return the output path.
        return self._output_path

    def _close(self) -> None:
        """Close the file handle if one is open (safe to call repeatedly)."""
        if self._handle is not None:
            self._handle.close()
            self._handle = None
@@ -0,0 +1,64 @@
1
+ """Text postprocessing utilities for diversify."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from diversify_text._preprocess import PreprocessContext
6
+
7
+
8
def reassemble_segments(
    segments_per_text: list[list[str]],
    paraphrases_by_segment: list[list[str]],
) -> list[list[str]]:
    """Join per-segment paraphrases back into per-original-text paraphrases.

    For each style, the paraphrases of one text's segments are joined with a
    single space, in segment order.

    Parameters
    ----------
    segments_per_text : list[list[str]]
        The sentence segments for each original text (from
        :func:`~diversify_text._preprocess.split_sentences`).
    paraphrases_by_segment : list[list[str]]
        Flat list of paraphrases for every segment, shape
        ``[total_segments][n_styles]``.

    Returns
    -------
    list[list[str]]
        Shape ``[n_texts][n_styles]`` — reassembled paraphrases. If some
        segment of a text has fewer styles than the others, only the styles
        present for *every* segment are returned for that text. A text with
        no segments gets an empty list.

    Raises
    ------
    ValueError
        If ``paraphrases_by_segment`` does not hold exactly one row per
        segment across all texts. A short list would otherwise make later
        texts silently lose segments.
    """
    expected = sum(len(segs) for segs in segments_per_text)
    if len(paraphrases_by_segment) != expected:
        raise ValueError(
            f"Expected {expected} segment paraphrase rows (one per segment), "
            f"got {len(paraphrases_by_segment)}."
        )

    result = []
    seg_idx = 0
    for segs in segments_per_text:
        seg_paras = paraphrases_by_segment[seg_idx : seg_idx + len(segs)]
        seg_idx += len(segs)
        # A style can only be reassembled if every segment has it; min()
        # avoids an IndexError on ragged rows. default=0 covers a text
        # with no segments.
        n_styles = min((len(sp) for sp in seg_paras), default=0)
        result.append([" ".join(sp[i] for sp in seg_paras) for i in range(n_styles)])
    return result
36
+
37
+
38
def postprocess(
    candidate: list[list[str]],
    context: PreprocessContext,
) -> list[list[str]]:
    """Reverse the preprocessing recorded in *context* on generated output.

    Every transformation applied by
    :func:`~diversify_text._preprocess.preprocess` is undone using the
    state it stored in *context*. Currently the only such step is sentence
    splitting: when segments were recorded, the per-segment paraphrases are
    joined back per original text.

    Parameters
    ----------
    candidate : list[list[str]]
        Raw generation output, shape ``[n_generation_texts][n_styles]``.
    context : PreprocessContext
        Context returned by :func:`~diversify_text._preprocess.preprocess`.

    Returns
    -------
    list[list[str]]
        Shape ``[n_texts][n_styles]`` — one paraphrase per original text
        per style. Returned unchanged (same object) when no splitting
        was done.
    """
    segments = context.segments_per_text
    if segments is None:
        # No sentence splitting happened: nothing to undo.
        return candidate
    return reassemble_segments(segments, candidate)
@@ -0,0 +1,76 @@
1
+ """Text preprocessing utilities for diversify."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+
7
+ import pysbd # https://github.com/nipunsadvilkar/pySBD, published at EMNLP 2020, rule-based
8
+
9
+
10
# Built once at import time and shared by every split_sentences() call.
# English rules; clean=False presumably leaves the input text un-normalised
# before splitting (TODO confirm against the pysbd docs).
_SEGMENTER = pysbd.Segmenter(clean=False, language="en")
11
+
12
+
13
@dataclass
class PreprocessContext:
    """Bookkeeping handed from :func:`preprocess` to
    :func:`~diversify_text._postprocess.postprocess`.

    It records whatever postprocess needs to undo the preprocessing steps.
    Future steps can store their state in new fields here without changing
    the caller in ``core.py``.
    """

    # Sentence segments of each original text, in input order. Filled in
    # only when preprocess() runs with split_on_punctuation=True; otherwise
    # it stays None.
    segments_per_text: list[list[str]] | None = None
23
+
24
+
25
def split_sentences(text: str) -> list[str]:
    """Break *text* into sentences with the shared pysbd segmenter.

    The text is stripped before segmenting, and every sentence is stripped
    too. Sentences that end up empty are dropped. When nothing is left
    (e.g. empty or whitespace-only input), the result is a one-element list
    holding the stripped input, so each text always maps to at least one
    segment.
    """
    stripped_text = text.strip()
    sentences = []
    for raw_sentence in _SEGMENTER.segment(stripped_text):
        sentence = raw_sentence.strip()
        if sentence:
            sentences.append(sentence)
    return sentences if sentences else [stripped_text]
35
+
36
+
37
def preprocess(
    texts: list[str],
    *,
    split_on_punctuation: bool = False,
) -> tuple[list[str], PreprocessContext]:
    """Turn a batch of original texts into the texts fed to generation.

    The returned :class:`PreprocessContext` carries whatever
    :func:`~diversify_text._postprocess.postprocess` needs to map the
    generated output back onto the original texts.

    Parameters
    ----------
    texts : list[str]
        Original input texts.
    split_on_punctuation : bool
        When ``True``, every text is cut into sentence segments and all
        segments are concatenated into one flat list. The per-text
        grouping is kept in ``context.segments_per_text`` so postprocess
        can stitch the pieces together again.

    Returns
    -------
    generation_texts : list[str]
        Texts to pass to the generation method. This is *texts* itself
        (the same list object) when no splitting is requested.
    context : PreprocessContext
        Context needed by :func:`~diversify_text._postprocess.postprocess`.
    """
    context = PreprocessContext()

    if not split_on_punctuation:
        # Nothing to transform, and therefore nothing to undo later.
        return texts, context

    segments = [split_sentences(text) for text in texts]
    context.segments_per_text = segments

    flat_segments: list[str] = []
    for text_segments in segments:
        flat_segments.extend(text_segments)
    return flat_segments, context
@@ -0,0 +1,27 @@
1
+ """Shared internal utilities for diversify."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import contextlib
6
+ import logging
7
+ import warnings
8
+
9
+
10
@contextlib.contextmanager
def suppress_hf_load_noise():
    """Temporarily silence expected noise while loading HuggingFace models.

    For the duration of the ``with`` block:

    - the ``"transformers"`` logger is raised to ``ERROR``, hiding its
      ``WARNING``/``INFO`` records. Per the original notes, this targets
      tied-weights notices and unexpected-key load reports from the
      style-embedding model; TODO confirm that the latter are emitted
      under the ``transformers`` logger;
    - Python warnings whose message matches ``.*tie.*weight.*`` are ignored.

    Both changes are undone on exit, including when the body raises. The
    warning filters are restored first, then the logger level.
    """
    hf_logger = logging.getLogger("transformers")
    with contextlib.ExitStack() as restore:
        # Callbacks unwind LIFO: registered first, so it runs last.
        restore.callback(hf_logger.setLevel, hf_logger.level)
        hf_logger.setLevel(logging.ERROR)
        restore.enter_context(warnings.catch_warnings())
        warnings.filterwarnings("ignore", message=".*tie.*weight.*")
        yield