waterfall 0.1.6__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {waterfall-0.1.6 → waterfall-0.2.0}/PKG-INFO +2 -2
- {waterfall-0.1.6 → waterfall-0.2.0}/pyproject.toml +2 -2
- {waterfall-0.1.6 → waterfall-0.2.0}/requirements.txt +1 -1
- waterfall-0.2.0/waterfall/WatermarkerBase.py +487 -0
- {waterfall-0.1.6 → waterfall-0.2.0}/waterfall/permute.py +9 -2
- {waterfall-0.1.6 → waterfall-0.2.0}/waterfall/watermark.py +131 -121
- waterfall-0.1.6/test.ipynb +0 -129
- waterfall-0.1.6/test_.ipynb +0 -118
- waterfall-0.1.6/waterfall/WatermarkerBase.py +0 -317
- {waterfall-0.1.6 → waterfall-0.2.0}/.gitignore +0 -0
- {waterfall-0.1.6 → waterfall-0.2.0}/Images/Illustration.gif +0 -0
- {waterfall-0.1.6 → waterfall-0.2.0}/Images/Problem_formulation.jpg +0 -0
- {waterfall-0.1.6 → waterfall-0.2.0}/Images/Watermarking_process.png +0 -0
- {waterfall-0.1.6 → waterfall-0.2.0}/LICENSE +0 -0
- {waterfall-0.1.6 → waterfall-0.2.0}/README.md +0 -0
- {waterfall-0.1.6 → waterfall-0.2.0}/__init__.py +0 -0
- {waterfall-0.1.6 → waterfall-0.2.0}/waterfall/WatermarkingFn.py +0 -0
- {waterfall-0.1.6 → waterfall-0.2.0}/waterfall/WatermarkingFnFourier.py +0 -0
- {waterfall-0.1.6 → waterfall-0.2.0}/waterfall/WatermarkingFnSquare.py +0 -0
- {waterfall-0.1.6 → waterfall-0.2.0}/waterfall/__init__.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: waterfall
|
|
3
|
-
Version: 0.1.6
|
|
3
|
+
Version: 0.2.0
|
|
4
4
|
Summary: Scalable Framework for Robust Text Watermarking and Provenance for LLMs
|
|
5
5
|
Project-URL: Homepage, https://github.com/aoi3142/Waterfall
|
|
6
6
|
Project-URL: Issues, https://github.com/aoi3142/Waterfall/issues
|
|
@@ -15,7 +15,7 @@ Requires-Dist: numpy>=2.0.0
|
|
|
15
15
|
Requires-Dist: scipy>=1.13.0
|
|
16
16
|
Requires-Dist: sentence-transformers>=3.0.0
|
|
17
17
|
Requires-Dist: torch>=2.3.0
|
|
18
|
-
Requires-Dist: transformers>=4.43.1
|
|
18
|
+
Requires-Dist: transformers<4.55.0,>=4.43.1
|
|
19
19
|
Description-Content-Type: text/markdown
|
|
20
20
|
|
|
21
21
|
# Waterfall: Scalable Framework for Robust Text Watermarking and Provenance for LLMs [EMNLP 2024 Main Long]
|
|
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "waterfall"
|
|
7
|
-
version = "0.1.6"
|
|
7
|
+
version = "0.2.0"
|
|
8
8
|
authors = [
|
|
9
9
|
{ name = "Xinyuan Niu", email="aperture@outlook.sg" }
|
|
10
10
|
]
|
|
@@ -24,7 +24,7 @@ dependencies = [
|
|
|
24
24
|
"scipy>=1.13.0",
|
|
25
25
|
"sentence-transformers>=3.0.0",
|
|
26
26
|
"torch>=2.3.0",
|
|
27
|
-
"transformers>=4.43.1",
|
|
27
|
+
"transformers>=4.43.1,<4.55.0",
|
|
28
28
|
]
|
|
29
29
|
|
|
30
30
|
[project.urls]
|
|
@@ -0,0 +1,487 @@
|
|
|
1
|
+
import gc
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
import time
|
|
5
|
+
from collections import defaultdict
|
|
6
|
+
from functools import partial
|
|
7
|
+
from multiprocessing import Pool
|
|
8
|
+
from typing import List, Tuple, Optional
|
|
9
|
+
from itertools import repeat
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import torch
|
|
13
|
+
from scipy.sparse import csr_matrix, vstack
|
|
14
|
+
from tqdm import tqdm
|
|
15
|
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
16
|
+
from transformers.modeling_utils import PreTrainedModel
|
|
17
|
+
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, BatchEncoding
|
|
18
|
+
from transformers.generation.logits_process import LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper
|
|
19
|
+
from transformers.generation.configuration_utils import GenerationConfig
|
|
20
|
+
|
|
21
|
+
from waterfall.permute import Permute
|
|
22
|
+
from waterfall.WatermarkingFn import WatermarkingFn
|
|
23
|
+
from waterfall.WatermarkingFnFourier import WatermarkingFnFourier
|
|
24
|
+
|
|
25
|
+
class PerturbationProcessor(LogitsProcessor):
    """Logits processor that adds the watermark perturbation ``phi`` to the scores.

    For each sequence in the batch, the ``n_gram - 1`` tokens preceding the
    token being generated, together with the watermark ``id``, select a
    permutation of the vocabulary through ``Permute.get_permutation``. ``phi``
    indexed by that permutation is added in place to the first ``N`` logits.
    """

    def __init__(self,
                 N : int = 32000,    # Vocab size
                 id : int = 0,       # Watermark ID
                 ) -> None:

        self.id = id
        self.N = N
        # Prompt length, captured on the first __call__ after a reset()
        self.init_token_count = None
        # Flat (all-zero) perturbation until set_phi() is called
        self.phi = torch.zeros(N)
        self.n_gram = 2

        self.skip_watermark = False

        self.permute = Permute(self.N)

    def reset(self, n_gram : int = 2) -> None:
        """Prepare for a new ``generate`` call.

        Stores ``n_gram``, forgets the previously recorded prompt length and
        disables watermarking (with a warning) when ``phi`` is constant, since
        a flat perturbation cannot change the token distribution.
        """
        self.n_gram = n_gram
        self.init_token_count = None
        if torch.allclose(self.phi, torch.median(self.phi)):
            self.skip_watermark = True
            logging.warning("Generating without watermark as watermarking function is flat")
        else:
            self.skip_watermark = False

    def set_phi(self, phi : np.ndarray) -> None:
        """Set the perturbation vector from a NumPy array (shares memory with ``phi``)."""
        self.phi = torch.from_numpy(phi)

    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor) -> torch.FloatTensor:
        """Add the permuted perturbation to ``scores`` and return ``scores``.

        ``scores`` is modified in place. It is returned untouched when
        watermarking is skipped or when fewer than ``n_gram - 1`` tokens have
        been generated since the prompt.
        """

        if self.skip_watermark:
            return scores

        seq_len = input_ids.shape[1]
        if self.init_token_count is None:
            self.init_token_count = seq_len

        # Number of preceding tokens that seed the permutation
        window = self.n_gram - 1

        # Insufficient tokens generated for n-gram
        if self.init_token_count + window > seq_len:
            return scores

        # using numpy as PyTorch tensors doesn't hash properly for rng and dict key
        # Slice from an explicit start index: the negative form `-window:` turns
        # into `0:` (the whole sequence) when window == 0, i.e. n_gram == 1,
        # whereas the context must then be empty.
        prev_tokens = input_ids[:, seq_len - window:].cpu().numpy()

        permutations = (
            self.permute.get_permutation(prev_tokens[i, :], self.id, cache=True)
            for i in range(prev_tokens.shape[0])
        )
        perturbations = torch.stack([
            self.phi[permutation] for permutation in permutations
        ])
        # Only the first N logits are perturbed; any extra entries are left as is
        scores[:, :self.N] += perturbations.to(device=scores.device, dtype=scores.dtype)
        return scores
|
|
78
|
+
|
|
79
|
+
def indices_to_counts(N : int, dtype : np.dtype, indices : np.ndarray) -> csr_matrix:
    """Turn rows of indices into a sparse matrix of per-row occurrence counts.

    Each row of ``indices`` becomes one row of length at least ``N`` whose
    entry ``k`` is the number of times ``k`` appears in that row, cast to
    ``dtype``.
    """
    histogram_rows = []
    for row in indices:
        histogram = np.bincount(row, minlength=N)
        histogram_rows.append(histogram.astype(dtype))
    return csr_matrix(histogram_rows)
|
|
82
|
+
|
|
83
|
+
class Watermarker:
    """Generate watermarked text with a causal LM and verify watermarks in text.

    Wraps a tokenizer, an optional ``transformers`` model, a watermarking
    function (which supplies ``phi`` and the scoring routine ``q``) and a
    ``PerturbationProcessor`` that applies ``phi`` during generation.
    """

    def __init__(self,
                 tokenizer : Optional[PreTrainedTokenizerBase | str] = None,
                 model : Optional[PreTrainedModel | str] = None,
                 id : int = 0,
                 kappa : float = 6,
                 k_p : int = 1,
                 n_gram : int = 2,
                 watermarkingFnClass = WatermarkingFnFourier,
                 device = None,
                 ) -> None:
        """Set up tokenizer, model, logits processor and watermarking function.

        Args:
            tokenizer: Tokenizer instance or name/path. When None, it is loaded
                from ``model`` (a name/path or a loaded model).
            model: Model instance, name/path to load, or None (verification only).
            id: Watermark ID.
            kappa: Passed to the watermarking function; must be >= 0.
            k_p: Passed to the watermarking function and used when scoring.
            n_gram: Default n-gram size; ``n_gram - 1`` preceding tokens seed
                each permutation.
            watermarkingFnClass: Class used to build the watermarking function.
            device: Forwarded as ``device_map`` when ``model`` is a string.

        Raises:
            NotImplementedError: If no tokenizer is given and ``model`` is
                neither a string nor a ``PreTrainedModel``.
        """
        assert kappa >= 0, f"kappa must be >= 0, value provided is {kappa}"

        self.id = id
        self.k_p = k_p
        self.n_gram = n_gram
        self.kappa = kappa

        if tokenizer is None:
            if isinstance(model, str):
                self.tokenizer = AutoTokenizer.from_pretrained(model)
            elif isinstance(model, PreTrainedModel):
                self.tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
            else:
                raise NotImplementedError("tokenizer must be provided or model must be a string or PreTrainedModel")
        elif isinstance(tokenizer, str):
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        else:
            self.tokenizer = tokenizer
        # Padding reuses the EOS token
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True

        self.N = self.tokenizer.vocab_size

        self.logits_processor = PerturbationProcessor(N = self.N, id = self.id)

        if isinstance(model, str):
            self.load_model(model, device_map=device)
        else:
            self.model = model

        assert (self.model is None) or isinstance(self.model, PreTrainedModel), f"model must be a transformers model, value provided is {type(self.model)}" # argument order for tokenizer and model were swapped since the original code

        self.compute_phi(watermarkingFnClass)

    def load_model(self, model_name_or_path : str, device_map : str = "auto"):
        """Load a causal LM from ``model_name_or_path`` into ``self.model``."""
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            device_map=device_map,
        )

    def compute_phi(self, watermarkingFnClass = WatermarkingFnFourier) -> None:
        """Build the watermarking function and push its ``phi`` to the logits processor."""
        self.watermarking_fn: WatermarkingFn = watermarkingFnClass(id = self.id, k_p = self.k_p, N = self.N, kappa = self.kappa)
        self.phi = self.watermarking_fn.phi

        self.logits_processor.set_phi(self.phi)

    # Format prompt(s) into chat template
    def format_prompt(
            self,
            T_os : str | List[str],
            system_prompt : Optional[str] = None,
            assistant_prefill : Optional[str | List[str]] = "",
            ) -> str | List[str]:
        """Render one or more user messages with the tokenizer's chat template.

        Args:
            T_os: A single user message or a list of them.
            system_prompt: Optional system message prepended to every chat.
            assistant_prefill: Text appended after the generation prompt; a
                single value is reused for every message, a list must match
                ``T_os`` in length. ``None`` entries append nothing.

        Returns:
            A string when ``T_os`` is a string, otherwise a list of strings.
        """
        # Only include a system turn when one was supplied; previously the
        # system message variable was left unbound when system_prompt was None.
        if isinstance(system_prompt, str):
            leading_messages = [{"role":"system", "content":system_prompt}]
        else:
            leading_messages = []

        is_single = isinstance(T_os, str)
        if is_single:
            T_os = [T_os]
        if not isinstance(assistant_prefill, list):
            assistant_prefill = repeat(assistant_prefill, len(T_os))
        else:
            assert len(assistant_prefill) == len(T_os), "Length of assistant_prefill must match length of T_os"

        formatted_prompts = []
        for T_o, prefill in zip(T_os, assistant_prefill):
            formatted_prompt : str = self.tokenizer.apply_chat_template(
                leading_messages + [
                    {"role":"user", "content":T_o},
                ], tokenize=False, add_generation_prompt = True)
            if prefill is not None:
                formatted_prompt += prefill
            formatted_prompts.append(formatted_prompt)
        if is_single:
            return formatted_prompts[0]
        return formatted_prompts

    # Find the largest batch size that fits in GPU memory
    def find_largest_batch_size(
            self,
            tokd_inputs : List[BatchEncoding],
            logits_processor : Optional[List[LogitsProcessor]] = None,
            **kwargs,
            ):
        """Probe for the largest power-of-two batch size that generates without OOM.

        A dummy all-zero input as long as the longest prompt plus
        ``max_new_tokens - 1`` is generated for a single new token, doubling
        the batch each round until a CUDA out-of-memory error occurs or the
        cap ``min(16, len(tokd_inputs))`` is reached.

        Returns:
            The largest batch size that completed successfully (at least 1).
        """
        # None instead of a shared mutable default list
        if logits_processor is None:
            logits_processor = []

        longest_idx = np.argmax([tokd_input["input_ids"].shape[-1] for tokd_input in tokd_inputs])
        if "generation_config" in kwargs:
            generation_config = GenerationConfig(**kwargs["generation_config"].to_dict()) # copy
            max_new_tokens = generation_config.max_new_tokens
        else:
            generation_config = GenerationConfig(**kwargs)
            max_new_tokens = kwargs.get("max_new_tokens", 2048)
        if max_new_tokens is None:
            # Config without max_new_tokens: use the same fallback as the
            # keyword path instead of failing on `None + int` below.
            max_new_tokens = 2048
        generation_config.update(max_new_tokens=1)
        input_ids = tokd_inputs[longest_idx]["input_ids"]
        # Worst-case context length reached on the final generation step
        input_ids = torch.zeros(
            (1, max_new_tokens + input_ids.shape[-1] - 1),
            dtype=input_ids.dtype,
            device=self.model.device
        )
        max_batch_size = 1
        with torch.no_grad():
            while max_batch_size < min(16, len(tokd_inputs)):
                torch.cuda.empty_cache()
                try:
                    _ = self.model.generate(
                        input_ids=input_ids,
                        attention_mask=torch.ones_like(input_ids),
                        logits_processor=logits_processor,
                        generation_config=generation_config,
                        pad_token_id=self.tokenizer.eos_token_id,
                        tokenizer=self.tokenizer,
                    )
                    max_batch_size = input_ids.shape[0]
                except RuntimeError as e:
                    if "CUDA out of memory" in str(e):
                        break
                    else:
                        raise e
                # Double the batch for the next probe
                input_ids = torch.cat([input_ids, input_ids], dim=0)
        torch.cuda.empty_cache()
        return max_batch_size

    def generate(
            self,
            prompts : Optional[str | List[str]] = None,
            tokd_inputs : Optional[torch.Tensor | List[torch.Tensor] | BatchEncoding | List[BatchEncoding]] = None,
            n_gram : Optional[int] = None,
            return_text : bool = True,
            return_tokens : bool = False,
            return_scores : bool = False,
            use_tqdm : bool = False,
            batched_generate : bool = True,
            **kwargs    # Other generate parameters
            ) -> List[str] | dict:   # Returns flattened list of query x beam
        """Generate (watermarked) continuations for prompts or pre-tokenized inputs.

        Args:
            prompts: Prompt string(s); tokenized without special tokens when
                ``tokd_inputs`` is not given.
            tokd_inputs: Pre-tokenized input(s), as tensors or encodings.
            n_gram: Overrides ``self.n_gram`` for this call.
            return_text / return_tokens / return_scores: Select the entries
                ``"text"``, ``"tokens"`` and ``"q_score"`` of the result.
            use_tqdm: Show a progress bar.
            batched_generate: Allow batching when at least 8 inputs are given.
            **kwargs: Forwarded to ``model.generate``; ``top_k``/``top_p`` are
                extracted so they are applied before the watermark.

        Returns:
            The decoded texts when only text is requested, otherwise a dict.
        """

        assert self.model is not None, "Model is not loaded. Please load the model before generating text."

        is_single = isinstance(prompts, str) or isinstance(tokd_inputs, torch.Tensor)
        if is_single:
            prompts = [prompts] if prompts is not None else None
            tokd_inputs = [tokd_inputs] if tokd_inputs is not None else None

        if n_gram is None:
            n_gram = self.n_gram
        if tokd_inputs is None:
            assert prompts is not None, "Either prompt or tokd_input must be provided."
            tokd_inputs = [self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False) for prompt in prompts]

        # If tokd_input is a tensor, convert it to a BatchEncoding
        squeezed_tokd_inputs = []
        for tokd_input in tokd_inputs:
            if isinstance(tokd_input, torch.Tensor):
                input_ids = tokd_input
                attention_mask = torch.ones_like(tokd_input)
            else:
                input_ids = tokd_input["input_ids"]
                attention_mask = tokd_input["attention_mask"]
            if input_ids.ndim == 2:
                # Drop only the batch dimension; a bare squeeze() would also
                # collapse a one-token prompt into a 0-d tensor.
                input_ids = input_ids.squeeze(0)
                attention_mask = attention_mask.squeeze(0)
            squeezed_tokd_inputs.append(BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask}))
        tokd_inputs = squeezed_tokd_inputs

        logits_processor = []
        # Ensure top_k and top_p happens before watermarking
        if "generation_config" in kwargs:
            generation_config: GenerationConfig = kwargs["generation_config"]
            top_k = generation_config.top_k
            top_p = generation_config.top_p
            # NOTE(review): this mutates the caller's config and leaves top_k
            # set, unlike the keyword path below which removes both - confirm intended.
            generation_config.update(top_p=1.0)
        else:
            top_k = kwargs.pop("top_k", None)
            top_p = kwargs.pop("top_p", None)

        if top_k is not None and top_k != 0:
            logits_processor.append(TopKLogitsWarper(top_k))
        if top_p is not None and top_p < 1.0:
            logits_processor.append(TopPLogitsWarper(top_p))
        if self.kappa != 0:
            logits_processor.append(self.logits_processor)

        if batched_generate and len(tokd_inputs) >= 8:
            max_batch_size = self.find_largest_batch_size(tokd_inputs, logits_processor=logits_processor, **kwargs)
        else:
            max_batch_size = 1

        # Group inputs by token length
        if max_batch_size > 1:
            tokd_inputs_order = sorted(range(len(tokd_inputs)), key=lambda i: tokd_inputs[i]["input_ids"].shape[-1])
            tokd_inputs = [tokd_inputs[i] for i in tokd_inputs_order]
        else:
            tokd_inputs_order = range(len(tokd_inputs))
        tokd_input_batches = []
        for i in range(0, len(tokd_inputs), max_batch_size):
            batch = self.tokenizer.pad(tokd_inputs[i:i+max_batch_size], padding=True, padding_side="left").to(self.model.device, non_blocking=True)
            tokd_input_batches.append(batch)
        # Wait for the non-blocking transfers; synchronize() raises without CUDA
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        outputs = []
        with torch.no_grad():
            bar = tqdm(total=len(tokd_inputs), desc="Generating text", disable=not use_tqdm)
            for tokd_input_batch in tokd_input_batches:
                self.logits_processor.reset(n_gram)
                output = self.model.generate(
                    **tokd_input_batch,
                    logits_processor=logits_processor,
                    pad_token_id=self.tokenizer.eos_token_id,
                    tokenizer=self.tokenizer,
                    **kwargs
                )
                # Keep only the newly generated tokens
                output = output[:,tokd_input_batch["input_ids"].shape[-1]:].to("cpu", non_blocking=True)
                outputs.append(output)
                bar.update(tokd_input_batch["input_ids"].shape[0])
            if torch.cuda.is_available():
                torch.cuda.synchronize()
        outputs = [j for i in outputs for j in i]   # Flatten the list of outputs

        # Restore original ordering
        if max_batch_size > 1:
            reordered_outputs = [None] * len(outputs)
            num_return_sequences = len(outputs) // len(tokd_inputs)
            for i, idx in enumerate(tokd_inputs_order):
                reordered_outputs[idx * num_return_sequences:(idx + 1) * num_return_sequences] = outputs[i * num_return_sequences:(i + 1) * num_return_sequences]
            outputs = reordered_outputs

        return_dict = {}

        if return_scores:
            cumulative_token_count = self.get_cumulative_token_count(self.id, outputs, n_gram = n_gram, return_dense=False)
            cumulative_token_count = vstack([i[0] for i in cumulative_token_count], format="csr")
            q_score, _, _ = self.watermarking_fn.q(cumulative_token_count, k_p = [self.k_p], use_tqdm=False)
            return_dict["q_score"] = q_score

        if return_tokens:
            return_dict["tokens"] = outputs

        if return_text:
            decoded_output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
            decoded_output = [i.strip() for i in decoded_output]
            return_dict["text"] = decoded_output

        if is_single:
            for k, v in return_dict.items():
                return_dict[k] = v[0]

        # Text-only requests return the list of decoded strings directly
        if return_text and len(return_dict) == 1:
            return decoded_output

        return return_dict

    @staticmethod
    def _pool_worker_count() -> int:
        """Number of worker processes to use: usable CPUs minus one, at least 1."""
        if hasattr(os, "sched_getaffinity"):
            usable_cpus = len(os.sched_getaffinity(0))
        else:
            # sched_getaffinity is not available on every platform
            usable_cpus = os.cpu_count() or 1
        return max(1, usable_cpus - 1)

    def get_cumulative_token_count(
            self,
            ids : List[int] | int,
            all_tokens : List[torch.Tensor] | torch.Tensor | List[np.ndarray] | np.ndarray | List[List[int]] | List[int],
            n_gram : int = 2,
            return_unshuffled_indices : bool = False,
            use_tqdm : bool = False,
            return_dense : bool = True,
            batch_size : int = 2**8,
            ) -> List[csr_matrix] | List[np.ndarray] | Tuple[List[csr_matrix], List[List[np.ndarray]]] | Tuple[List[np.ndarray], List[List[np.ndarray]]]:
        """Count, per text and per watermark ID, how often each unshuffled index occurs.

        Every token (from position ``n_gram - 1`` on) is mapped, for each ID,
        to an index obtained from ``Permute.get_unshuffled_indices`` using the
        preceding ``n_gram - 1`` tokens; the indices are then histogrammed.
        Scanning of a text stops at the first token id ``>= self.N``.

        Args:
            ids: Watermark ID(s) to evaluate.
            all_tokens: One token sequence or a list of them.
            n_gram: n-gram size.
            return_unshuffled_indices: Also return the per-text index arrays.
            use_tqdm: Show progress bars.
            return_dense: Return dense arrays instead of CSR matrices.
            batch_size: Chunk size for multiprocessing, which is enabled when
                more than ``4 * batch_size`` texts are given.

        Returns:
            A list with one count matrix per text (rows are IDs), optionally
            paired with the list of unshuffled index arrays.
        """
        if isinstance(ids, int):
            ids = [ids]
        if isinstance(all_tokens[0], int) or (isinstance(all_tokens, (np.ndarray, torch.Tensor)) and all_tokens.ndim == 1):
            all_tokens = [all_tokens]
        all_tokens = list(map(lambda x: x.cpu().numpy() if isinstance(x, torch.Tensor) else x, all_tokens))
        max_length = max(map(len, all_tokens))
        window = n_gram - 1

        # Collect all unique seeds for pseudo-random number generation
        key_index_dict = defaultdict(set)
        all_keys = []
        for i, tokens in enumerate(tqdm(all_tokens, desc="Collecting unique n-grams", disable=not use_tqdm)):
            all_keys.append([])
            for j in range(window, len(tokens)):
                prev_token = tuple(tokens[j-window:j])
                t = tokens[j]
                if t >= self.N:
                    break
                key_index_dict[prev_token].add(t)
                all_keys[i].append((prev_token, t))
        key_index_dict = {k:tuple(v) for k,v in key_index_dict.items()}

        use_mp = len(all_tokens) > batch_size * 4
        pool = None
        if use_mp:
            pool = Pool(self._pool_worker_count())
            pool_map = partial(pool.imap, chunksize=batch_size)
        else:
            pool_map = map

        try:
            # Generate permutations for all unique seeds
            permutations = pool_map(
                partial(self.logits_processor.permute.get_unshuffled_indices, ids),
                key_index_dict.items())
            permutations = tqdm(permutations, total=len(key_index_dict), desc="Getting permutations", disable=not use_tqdm)
            # Replace each key's token tuple with its computed lookup (same key order)
            for k, value in zip(key_index_dict.keys(), permutations):
                key_index_dict[k] = value

            # Assign indices to unshuffled_indices
            unshuffled_indices: List[np.ndarray] = []   # [text x id x length]
            for keys in tqdm(all_keys, desc="Assigning indices", disable=not use_tqdm):
                if len(keys) == 0:
                    unshuffled_indices.append(np.zeros((len(ids), 0), dtype=np.min_scalar_type(self.N)))
                else:
                    unshuffled_indices.append(np.stack([key_index_dict[key][t] for key, t in keys]).T)  # [id x length]

            # Convert indices to counts
            cumulative_token_count = pool_map(
                partial(indices_to_counts, self.N, np.min_scalar_type(max_length)),
                unshuffled_indices
            )
            cumulative_token_count = list(tqdm(cumulative_token_count, total=len(unshuffled_indices), desc="Counting tokens", disable=not use_tqdm))
        except BaseException:
            # Do not leave worker processes behind when something fails
            if pool is not None:
                pool.terminate()
            raise
        else:
            if pool is not None:
                pool.close()
        finally:
            if pool is not None:
                pool.join()

        if return_dense:
            cumulative_token_count = list(map(lambda x: x.toarray(), cumulative_token_count))

        if return_unshuffled_indices:
            return cumulative_token_count, unshuffled_indices
        return cumulative_token_count

    def verify(
            self,
            text : str | List[str],
            id: Optional[int | List[int]] = None,
            k_p : Optional[int | List[int]] = None,
            return_ranking : bool = False,
            return_extracted_k_p : bool = False,
            return_counts : bool = False,
            return_unshuffled_indices : bool = False,
            use_tqdm : bool = False,
            batch_size : int = 2**8,
            ) -> np.ndarray | dict:
        """Score text(s) for the presence of the watermark.

        Args:
            text: A text or list of texts; tokenized without special tokens.
            id: Watermark ID(s) to test; defaults to ``self.id``.
            k_p: ``k_p`` value(s) to test; defaults to ``self.k_p``.
            return_ranking / return_extracted_k_p / return_counts /
            return_unshuffled_indices: Request the extra dict entries
                ``"ranking"``, ``"k_p_extracted"``, ``"counts"`` and
                ``"unshuffled_indices"``.
            use_tqdm: Show progress bars and timing messages.
            batch_size: Forwarded to the counting and scoring steps.

        Returns:
            The q-score array shaped [text x ids x k_p], or a dict holding
            ``"q_score"`` plus the requested extras.
        """
        begin_time = time.time()

        if id is None:
            id = self.id

        if isinstance(text, str):
            texts = [text]
        else:
            texts = text

        tokens = [np.array(self.tokenizer.encode(single_text, add_special_tokens=False), dtype=np.uint32) for single_text in tqdm(texts, desc="Tokenizing", disable=not use_tqdm)]

        if isinstance(id, int):
            ids = [id]
        else:
            ids = id

        if k_p is None:
            k_p = self.k_p

        if isinstance(k_p, int):
            k_ps = [k_p]
        else:
            k_ps = k_p

        # Get cumulative token counts
        start_time = time.time()
        results = self.get_cumulative_token_count(ids, tokens, self.n_gram, return_unshuffled_indices, use_tqdm=use_tqdm, return_dense=False, batch_size=batch_size)
        gc.collect()
        if return_unshuffled_indices:
            results, unshuffled_indices = results
        # One row per (text, id) pair
        results = vstack(results, format="csr")
        if use_tqdm:
            tqdm.write(f"Cummulative token counts done in {time.time() - start_time:.2f} seconds")

        # Calculate Q score via dot product
        start_time = time.time()
        q_score, ranking, k_p_extracted = self.watermarking_fn.q(results, k_p = k_ps, batch = batch_size, use_tqdm = use_tqdm)
        q_score, ranking = [i.reshape(-1, len(ids), i.shape[-1]) for i in (q_score, ranking)]   # [text x ids x k_p for i in (score, rank)]
        k_p_extracted = k_p_extracted.reshape(-1, len(ids))                                     # [text x ids]
        if use_tqdm:
            tqdm.write(f"Q score calculated in {time.time() - start_time:.2f} seconds")

        res = q_score   # [text x ids x k_p]

        if return_ranking or return_extracted_k_p or return_counts or return_unshuffled_indices:
            res = {
                "q_score": q_score,                             # [text x ids x k_p]
            }
            if return_ranking:
                res["ranking"] = ranking                        # [text x ids x k_p]
            if return_extracted_k_p:
                res["k_p_extracted"] = k_p_extracted            # [text x ids]
            if return_counts:
                res["counts"] = results                         # stacked CSR counts, one row per (text, id)
            if return_unshuffled_indices:
                res["unshuffled_indices"] = unshuffled_indices  # [text x ids x length]

        if use_tqdm:
            tqdm.write(f"Total time taken for verify: {time.time() - begin_time:.2f} seconds")

        return res
|
|
@@ -49,16 +49,23 @@ class Permute:
|
|
|
49
49
|
size_per_permutation_in_bytes = N * self.dtype.itemsize
|
|
50
50
|
cache_size = int(psutil.virtual_memory().total * 0.02 / size_per_permutation_in_bytes) # 2% of total memory
|
|
51
51
|
self.permutations.capacity = cache_size
|
|
52
|
+
self.no_permutation = np.arange(self.N, dtype=self.dtype)
|
|
53
|
+
|
|
54
|
+
def _permute(self, key):
    """Return the permutation of ``range(self.N)`` seeded by ``key``, cast to ``self.dtype``."""
    rng = np.random.RandomState(key)
    shuffled = rng.permutation(self.N)
    return shuffled.astype(self.dtype)
|
|
52
56
|
|
|
53
57
|
def get_permutation(self, prev_tok, id : int, cache : bool = False) -> np.ndarray:
    """Return the vocabulary permutation keyed by ``(id, *prev_tok)``.

    When any token in ``prev_tok`` is ``>= self.N`` the stored
    ``self.no_permutation`` is returned. With ``cache`` set, the permutation
    is looked up in ``self.permutations`` and stored there on a miss;
    otherwise it is recomputed on every call.
    """
    # Skip special tokens
    for tok in prev_tok:
        if tok >= self.N:
            return self.no_permutation

    seed = (id, *prev_tok)
    if not cache:
        return self._permute(seed)

    cached = self.permutations.get(seed)
    if cached is None:
        cached = self._permute(seed)
        self.permutations.put(seed, cached)
    return cached
|
|
63
70
|
|
|
64
71
|
def get_unshuffled_indices(self, ids, args) -> dict[int, np.ndarray]:
|