waterfall 0.1.3-py3-none-any.whl → 0.1.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- waterfall/WatermarkerBase.py +1 -0
- waterfall/watermark.py +31 -6
- {waterfall-0.1.3.dist-info → waterfall-0.1.4.dist-info}/METADATA +1 -1
- {waterfall-0.1.3.dist-info → waterfall-0.1.4.dist-info}/RECORD +7 -7
- {waterfall-0.1.3.dist-info → waterfall-0.1.4.dist-info}/WHEEL +0 -0
- {waterfall-0.1.3.dist-info → waterfall-0.1.4.dist-info}/entry_points.txt +0 -0
- {waterfall-0.1.3.dist-info → waterfall-0.1.4.dist-info}/licenses/LICENSE +0 -0
waterfall/WatermarkerBase.py
CHANGED
waterfall/watermark.py
CHANGED

@@ -1,10 +1,12 @@
 import argparse
 import logging
 import os
+import gc
 import torch
 from typing import List, Literal, Optional, Tuple

 from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers.modeling_utils import PreTrainedModel
 from sentence_transformers import SentenceTransformer
 from tqdm.auto import tqdm

@@ -21,6 +23,8 @@ PROMPT = (
 )
 PRE_PARAPHRASED = "Here is a paraphrased version of the text while preserving the semantic similarity:\n\n"

+waterfall_cached_watermarking_model = None # Global variable to cache the watermarking model
+
 def detect_gpu() -> str:
     """
     Use torch to detect if MPS, CUDA, or neither (default CPU)
@@ -42,9 +46,10 @@ def watermark(
     sts_model: SentenceTransformer,
     num_beam_groups: int = 4,
     beams_per_group: int = 2,
-    STS_scale:float = 2.0,
+    STS_scale: float = 2.0,
     diversity_penalty: float = 0.5,
     max_new_tokens: Optional[int] = None,
+    **kwargs
 ) -> str:
     paraphrasing_prompt = watermarker.tokenizer.apply_chat_template(
         [
@@ -61,6 +66,7 @@ def watermark(
         num_beam_groups = num_beam_groups,
         num_return_sequences = num_beam_groups * beams_per_group,
         diversity_penalty = diversity_penalty,
+        **kwargs,
     )

     # Select best paraphrasing based on q_score and semantic similarity
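
The two hunks above thread a new **kwargs parameter from watermark() into the underlying text-generation call, next to the diverse beam search settings (num_beam_groups, beams_per_group, diversity_penalty) the function already used. As a rough illustration only, and not code from this package, the sketch below shows how those settings plus one extra keyword argument behave at the Hugging Face generate() level; the checkpoint name is a placeholder chosen only because it is small.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint for illustration; any causal LM would do.
model_name = "sshleifer/tiny-gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

num_beam_groups, beams_per_group = 4, 2
inputs = tokenizer("Paraphrase: the cat sat on the mat.", return_tensors="pt")

outputs = model.generate(
    **inputs,
    num_beams=num_beam_groups * beams_per_group,   # num_beams must be divisible by num_beam_groups
    num_beam_groups=num_beam_groups,
    num_return_sequences=num_beam_groups * beams_per_group,
    diversity_penalty=0.5,
    do_sample=False,       # diverse beam search is a deterministic decoding mode
    max_new_tokens=20,     # the kind of extra kwarg that **kwargs now lets callers pass through
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

The total beam count has to be divisible by num_beam_groups, which is why the call multiplies num_beam_groups by beams_per_group, mirroring the num_return_sequences expression in the hunk.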
@@ -140,6 +146,7 @@ def watermark_texts(
     diversity_penalty: float = 0.5,
     STS_scale:float = 2.0,
     use_tqdm: bool = False,
+    stop_at_double_newline: bool = True, # if True, will stop generation at the first double newline. Prevent repeated paraphrasing of the same text.
 ) -> List[str]:
     if watermark_fn == 'fourier':
         watermarkingFnClass = WatermarkingFnFourier
@@ -150,11 +157,25 @@ def watermark_texts(

     if watermarker is None:
         assert model_path is not None, "model_path must be provided if watermarker is not passed"
-        model = AutoModelForCausalLM.from_pretrained(
-            model_path,
-            torch_dtype=torch_dtype,
-            device_map=device,
-        )
+        global waterfall_cached_watermarking_model
+
+        if isinstance(waterfall_cached_watermarking_model, PreTrainedModel) and waterfall_cached_watermarking_model.name_or_path != model_path:
+            device = waterfall_cached_watermarking_model.device.type
+            del waterfall_cached_watermarking_model
+            gc.collect()
+            if device == "cuda":
+                torch.cuda.empty_cache()
+            elif device == "mps":
+                torch.mps.empty_cache()
+            waterfall_cached_watermarking_model = None
+
+        if waterfall_cached_watermarking_model is None:
+            waterfall_cached_watermarking_model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                torch_dtype=torch_dtype,
+                device_map=device,
+            )
+        model = waterfall_cached_watermarking_model
         tokenizer = AutoTokenizer.from_pretrained(model_path)

         watermarker = Watermarker(tokenizer=tokenizer, model=model, id=id, kappa=kappa, k_p=k_p, watermarkingFnClass=watermarkingFnClass)
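
The hunk above swaps a per-call model load for a module-level cache: repeated calls to watermark_texts() with the same model_path reuse the already-loaded model, and requesting a different checkpoint first drops the cached one and frees CUDA or MPS memory before loading again. Below is a minimal, self-contained sketch of that caching pattern, not the package's own code; get_cached_model is a name invented for this example.

import gc
import torch
from transformers import AutoModelForCausalLM
from transformers.modeling_utils import PreTrainedModel

_cached_model = None  # module-level cache, analogous to waterfall_cached_watermarking_model


def get_cached_model(model_path: str, device: str = "cpu",
                     torch_dtype: torch.dtype = torch.float32) -> PreTrainedModel:
    """Load model_path once and reuse it; evict the cache when the path changes."""
    global _cached_model

    # A different checkpoint was requested: drop the cached model and release
    # accelerator memory before loading the new weights.
    if isinstance(_cached_model, PreTrainedModel) and _cached_model.name_or_path != model_path:
        old_device = _cached_model.device.type
        _cached_model = None
        gc.collect()
        if old_device == "cuda":
            torch.cuda.empty_cache()
        elif old_device == "mps":
            torch.mps.empty_cache()

    if _cached_model is None:
        _cached_model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
            device_map=device,  # mirrors the hunk above; device_map needs the accelerate package
        )
    return _cached_model

Calling get_cached_model twice with the same path loads the weights once; switching paths takes the cleanup branch first, which is the behaviour the new waterfall_cached_watermarking_model logic adds.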
@@ -173,6 +194,9 @@ def watermark_texts(
     T_ws = []

     for T_o in tqdm(T_os, desc="Watermarking texts", disable=not use_tqdm):
+        if stop_at_double_newline and "\n\n" in T_o:
+            logging.warning("Text contains \\n\\n and stop_at_double_newline is set to True, replacing all \\n\\n in text.")
+            T_o = T_o.replace("\n\n", " ") # replace double newlines with space
         T_w = watermark(
             T_o,
             watermarker = watermarker,
@@ -181,6 +205,7 @@ def watermark_texts(
             beams_per_group = beams_per_group,
             diversity_penalty = diversity_penalty,
             STS_scale = STS_scale,
+            stop_strings=["\n\n"] if stop_at_double_newline else None,
         )
         T_ws.append(T_w)

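
The stop_strings list built above travels through watermark()'s new **kwargs into the generation call, so decoding can halt as soon as the paraphrase emits a double newline. The fragment below is only a sketch of how such a value behaves when handed to Hugging Face's generate() directly; recent transformers releases expect a tokenizer to be passed alongside stop_strings so the strings can be matched, and how the waterfall Watermarker wires this internally is not visible in this diff. The checkpoint name is again a placeholder.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "sshleifer/tiny-gpt2"  # placeholder checkpoint for illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

stop_at_double_newline = True
inputs = tokenizer("Paraphrase the following text:", return_tensors="pt")

outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    stop_strings=["\n\n"] if stop_at_double_newline else None,
    tokenizer=tokenizer,  # generate() needs the tokenizer to match stop strings
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))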

{waterfall-0.1.3.dist-info → waterfall-0.1.4.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: waterfall
-Version: 0.1.3
+Version: 0.1.4
 Summary: Scalable Framework for Robust Text Watermarking and Provenance for LLMs
 Project-URL: Homepage, https://github.com/aoi3142/Waterfall
 Project-URL: Issues, https://github.com/aoi3142/Waterfall/issues

{waterfall-0.1.3.dist-info → waterfall-0.1.4.dist-info}/RECORD
CHANGED

@@ -1,12 +1,12 @@
-waterfall/WatermarkerBase.py,sha256=
+waterfall/WatermarkerBase.py,sha256=AyScrZz3hdjikvz5Fm4-B4acDz46i5wDFwCBg6Fp-vY,12947
 waterfall/WatermarkingFn.py,sha256=-b-kGRdL0a7eKRqJmcHPAR_rCjxQYnsg1Ne6bTwBc1I,1931
 waterfall/WatermarkingFnFourier.py,sha256=QYayAQYwi1dQkDIyqmvhU568VhrVYTVy47HkI8F8SZs,1358
 waterfall/WatermarkingFnSquare.py,sha256=2PAO05DdKT02npo7GDf_82D520nP7kGAWK6H4E4JMt4,1638
 waterfall/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 waterfall/permute.py,sha256=RwxOHFhx_VSOhhFwy5s79YgwTUBkfW2-LCCXYR3VT2o,2582
-waterfall/watermark.py,sha256=
-waterfall-0.1.
-waterfall-0.1.
-waterfall-0.1.
-waterfall-0.1.
-waterfall-0.1.
+waterfall/watermark.py,sha256=h7e1z8vWTUAKxCcQsJ2Jkx_1ZL-ug2dEDs5FzWcYfCs,13332
+waterfall-0.1.4.dist-info/METADATA,sha256=3hBQwb1JyrTWrayLCPFxXVlTpPjuE-ukPstW5F9F9rg,8715
+waterfall-0.1.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+waterfall-0.1.4.dist-info/entry_points.txt,sha256=XXnUzuWXu2nc9j4WAll9tq6HyodN_8WJLjeG0O4Y2Gw,60
+waterfall-0.1.4.dist-info/licenses/LICENSE,sha256=zAtaO-k41Q-Q4Etl4bzuh7pgNJsPH-dYfzvznRa0OvM,11341
+waterfall-0.1.4.dist-info/RECORD,,

{waterfall-0.1.3.dist-info → waterfall-0.1.4.dist-info}/WHEEL
File without changes

{waterfall-0.1.3.dist-info → waterfall-0.1.4.dist-info}/entry_points.txt
File without changes

{waterfall-0.1.3.dist-info → waterfall-0.1.4.dist-info}/licenses/LICENSE
File without changes