waterfall 0.1.7__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
waterfall/WatermarkerBase.py CHANGED
@@ -6,14 +6,17 @@ from collections import defaultdict
  from functools import partial
  from multiprocessing import Pool
  from typing import List, Tuple, Optional
+ from itertools import repeat

  import numpy as np
  import torch
  from scipy.sparse import csr_matrix, vstack
  from tqdm import tqdm
+ from transformers import AutoTokenizer, AutoModelForCausalLM
  from transformers.modeling_utils import PreTrainedModel
- from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+ from transformers.tokenization_utils_base import PreTrainedTokenizerBase, BatchEncoding
  from transformers.generation.logits_process import LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper
+ from transformers.generation.configuration_utils import GenerationConfig

  from waterfall.permute import Permute
  from waterfall.WatermarkingFn import WatermarkingFn
@@ -28,7 +31,7 @@ class PerturbationProcessor(LogitsProcessor):
  self.id = id
  self.N = N
  self.init_token_count = None
- self.phi = np.ones(N)
+ self.phi = torch.zeros(N)
  self.n_gram = 2

  self.skip_watermark = False
@@ -38,14 +41,14 @@ class PerturbationProcessor(LogitsProcessor):
  def reset(self, n_gram : int = 2) -> None:
  self.n_gram = n_gram
  self.init_token_count = None
- if np.allclose(self.phi,np.median(self.phi)):
+ if torch.allclose(self.phi,torch.median(self.phi)):
  self.skip_watermark = True
  logging.warning(f"Generating without watermark as watermarking function is flat")
  else:
  self.skip_watermark = False

  def set_phi(self, phi : np.ndarray) -> None:
- self.phi = phi
+ self.phi = torch.from_numpy(phi)

  def __call__(self, input_ids: torch.LongTensor,
  scores: torch.FloatTensor) -> torch.FloatTensor:
@@ -60,12 +63,17 @@ class PerturbationProcessor(LogitsProcessor):
  if self.init_token_count + self.n_gram - 1 > input_ids.shape[1]:
  return scores

+ # using numpy as PyTorch tensors doesn't hash properly for rng and dict key
  prev_tokens = input_ids[:,-self.n_gram+1:].cpu().numpy()
- permutations = [self.permute.get_permutation(prev_tokens[i,:], self.id, cache=True) for i in range(prev_tokens.shape[0])]

- scores[:,:self.N] += torch.tensor(self.phi[permutations],
- device=scores.device,
- dtype=scores.dtype)
+ permutations = (
+ self.permute.get_permutation(prev_tokens[i,:], self.id, cache=True)
+ for i in range(prev_tokens.shape[0])
+ )
+ perturbations = torch.stack([
+ self.phi[permutation] for permutation in permutations
+ ])
+ scores[:,:self.N] += perturbations.to(device=scores.device, dtype=scores.dtype)
  return scores

  def indices_to_counts(N : int, dtype : np.dtype, indices : np.ndarray) -> csr_matrix:
@@ -74,99 +82,258 @@ def indices_to_counts(N : int, dtype : np.dtype, indices : np.ndarray) -> csr_ma

  class Watermarker:
  def __init__(self,
- tokenizer : PreTrainedTokenizerBase,
- model : Optional[PreTrainedModel] = None,
+ tokenizer : Optional[PreTrainedTokenizerBase | str] = None,
+ model : Optional[PreTrainedModel | str] = None,
  id : int = 0,
  kappa : float = 6,
  k_p : int = 1,
  n_gram : int = 2,
- watermarkingFnClass = WatermarkingFnFourier
+ watermarkingFnClass = WatermarkingFnFourier,
+ device = None,
  ) -> None:
  assert kappa >= 0, f"kappa must be >= 0, value provided is {kappa}"

- assert (model is None) or isinstance(model, PreTrainedModel), f"model must be a transformers model, value provided is {type(model)}" # argument order for tokenizer and model were swapped since the original code
-
- self.tokenizer = tokenizer
- self.model = model
  self.id = id
  self.k_p = k_p
  self.n_gram = n_gram
  self.kappa = kappa

+ if tokenizer is None:
+ if isinstance(model, str):
+ self.tokenizer = AutoTokenizer.from_pretrained(model)
+ elif isinstance(model, PreTrainedModel):
+ self.tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
+ else:
+ raise NotImplementedError("tokenizer must be provided or model must be a string or PreTrainedModel")
+ elif isinstance(tokenizer, str):
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+ else:
+ self.tokenizer = tokenizer
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+ self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
+
  self.N = self.tokenizer.vocab_size
- self.logits_processor = PerturbationProcessor(N = self.N, id = id)
+
+ self.logits_processor = PerturbationProcessor(N = self.N, id = self.id)
+
+ if isinstance(model, str):
+ self.load_model(model, device_map=device)
+ else:
+ self.model = model
+
+ assert (self.model is None) or isinstance(self.model, PreTrainedModel), f"model must be a transformers model, value provided is {type(self.model)}" # argument order for tokenizer and model were swapped since the original code

  self.compute_phi(watermarkingFnClass)

+ def load_model(self, model_name_or_path : str, device_map : str = "auto"):
+ self.model = AutoModelForCausalLM.from_pretrained(
+ model_name_or_path,
+ device_map=device_map,
+ )
+
  def compute_phi(self, watermarkingFnClass = WatermarkingFnFourier) -> None:
- self.watermarking_fn: WatermarkingFn = watermarkingFnClass(id = id, k_p = self.k_p, N = self.N, kappa = self.kappa)
+ self.watermarking_fn: WatermarkingFn = watermarkingFnClass(id = self.id, k_p = self.k_p, N = self.N, kappa = self.kappa)
  self.phi = self.watermarking_fn.phi

  self.logits_processor.set_phi(self.phi)

+ # Format prompt(s) into chat template
+ def format_prompt(
+ self,
+ T_os : str | List[str],
+ system_prompt : Optional[str] = None,
+ assistant_prefill : Optional[str | List[str]] = "",
+ ) -> str | List[str]:
+ if isinstance(system_prompt, str):
+ _system_prompt = {"role":"system", "content":system_prompt}
+ is_single = isinstance(T_os, str)
+ if is_single:
+ T_os = [T_os]
+ if not isinstance(assistant_prefill, list):
+ assistant_prefill = repeat(assistant_prefill, len(T_os))
+ else:
+ assert len(assistant_prefill) == len(T_os), "Length of assistant_prefill must match length of T_os"
+ formatted_prompts = []
+ for T_o, prefill in zip(T_os, assistant_prefill):
+ formatted_prompt : str = self.tokenizer.apply_chat_template(
+ [
+ _system_prompt,
+ {"role":"user", "content":T_o},
+ ], tokenize=False, add_generation_prompt = True)
+ if prefill is not None:
+ formatted_prompt += prefill
+ formatted_prompts.append(formatted_prompt)
+ if is_single:
+ return formatted_prompts[0]
+ return formatted_prompts
+
+ # Find the largest batch size that fits in GPU memory
+ def find_largest_batch_size(
+ self,
+ tokd_inputs : List[BatchEncoding],
+ logits_processor : List[LogitsProcessor] = [],
+ **kwargs,
+ ):
+ longest_idx = np.argmax([tokd_input["input_ids"].shape[-1] for tokd_input in tokd_inputs])
+ if "generation_config" in kwargs:
+ generation_config = GenerationConfig(**kwargs["generation_config"].to_dict()) # copy
+ max_new_tokens = generation_config.max_new_tokens
+ else:
+ generation_config = GenerationConfig(**kwargs)
+ max_new_tokens = kwargs.get("max_new_tokens", 2048)
+ generation_config.update(max_new_tokens=1)
+ input_ids = tokd_inputs[longest_idx]["input_ids"]
+ input_ids = torch.zeros(
+ (1, max_new_tokens + input_ids.shape[-1] - 1),
+ dtype=input_ids.dtype,
+ device=self.model.device
+ )
+ max_batch_size = 1
+ with torch.no_grad():
+ while max_batch_size < min(16, len(tokd_inputs)):
+ torch.cuda.empty_cache()
+ try:
+ _ = self.model.generate(
+ input_ids=input_ids,
+ attention_mask=torch.ones_like(input_ids),
+ logits_processor=logits_processor,
+ generation_config=generation_config,
+ pad_token_id=self.tokenizer.eos_token_id,
+ tokenizer=self.tokenizer,
+ )
+ max_batch_size = input_ids.shape[0]
+ except RuntimeError as e:
+ if "CUDA out of memory" in str(e):
+ break
+ else:
+ raise e
+ input_ids = torch.cat([input_ids, input_ids], dim=0)
+ torch.cuda.empty_cache()
+ return max_batch_size
+
  def generate(
  self,
- prompt : Optional[str] = None,
- tokd_input : Optional[torch.Tensor] = None,
+ prompts : Optional[str | List[str]] = None,
+ tokd_inputs : Optional[torch.Tensor | List[torch.Tensor] | BatchEncoding | List[BatchEncoding]] = None,
  n_gram : Optional[int] = None,
- max_new_tokens : int = 1000,
- return_text : bool =True,
- return_tokens : bool =False,
- return_scores : bool =False,
- do_sample : bool =True,
- **kwargs
- ) -> List[str] | dict:
+ return_text : bool = True,
+ return_tokens : bool = False,
+ return_scores : bool = False,
+ use_tqdm : bool = False,
+ batched_generate : bool = True,
+ **kwargs # Other generate parameters
+ ) -> List[str] | dict: # Returns flattened list of query x beam

  assert self.model is not None, "Model is not loaded. Please load the model before generating text."

+ is_single = isinstance(prompts, str) or isinstance(tokd_inputs, torch.Tensor)
+ if is_single:
+ prompts = [prompts] if prompts is not None else None
+ tokd_inputs = [tokd_inputs] if tokd_inputs is not None else None
+
  if n_gram is None:
  n_gram = self.n_gram
- if tokd_input is None:
- assert prompt is not None, "Either prompt or tokd_input must be provided."
- tokd_input = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
- tokd_input = tokd_input.to(self.model.device)
+ if tokd_inputs is None:
+ assert prompts is not None, "Either prompt or tokd_input must be provided."
+ tokd_inputs = [self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False) for prompt in prompts]
+
+ # If tokd_input is a tensor, convert it to a BatchEncoding
+ squeezed_tokd_inputs = []
+ for tokd_input in tokd_inputs:
+ if isinstance(tokd_input, torch.Tensor):
+ input_ids = tokd_input
+ attention_mask = torch.ones_like(tokd_input)
+ else:
+ input_ids = tokd_input["input_ids"]
+ attention_mask = tokd_input["attention_mask"]
+ if input_ids.ndim == 2:
+ input_ids = input_ids.squeeze()
+ attention_mask = attention_mask.squeeze()
+ squeezed_tokd_inputs.append(BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask}))
+ tokd_inputs = squeezed_tokd_inputs
+
  logits_processor = []
- if "top_k" in kwargs and kwargs["top_k"] is not None and kwargs["top_k"] != 0:
- logits_processor.append(TopKLogitsWarper(kwargs.pop("top_k")))
- if "top_p" in kwargs and kwargs["top_p"] is not None and kwargs["top_p"] < 1.0:
- logits_processor.append(TopPLogitsWarper(kwargs.pop("top_p")))
+ # Ensure top_k and top_p happens before watermarking
+ if "generation_config" in kwargs:
+ generation_config: GenerationConfig = kwargs["generation_config"]
+ top_k = generation_config.top_k
+ top_p = generation_config.top_p
+ generation_config.update(top_p=1.0)
+ else:
+ top_k = kwargs.pop("top_k", None)
+ top_p = kwargs.pop("top_p", None)
+
+ if top_k is not None and top_k != 0:
+ logits_processor.append(TopKLogitsWarper(top_k))
+ if top_p is not None and top_p < 1.0:
+ logits_processor.append(TopPLogitsWarper(top_p))
  if self.kappa != 0:
  logits_processor.append(self.logits_processor)

+ if batched_generate and len(tokd_inputs) >= 8:
+ max_batch_size = self.find_largest_batch_size(tokd_inputs, logits_processor=logits_processor, **kwargs)
+ else:
+ max_batch_size = 1
+
+ # Group inputs by token length
+ if max_batch_size > 1:
+ tokd_inputs_order = sorted(range(len(tokd_inputs)), key=lambda i: tokd_inputs[i]["input_ids"].shape[-1])
+ tokd_inputs = [tokd_inputs[i] for i in tokd_inputs_order]
+ else:
+ tokd_inputs_order = range(len(tokd_inputs))
+ tokd_input_batches = []
+ for i in range(0, len(tokd_inputs), max_batch_size):
+ batch = self.tokenizer.pad(tokd_inputs[i:i+max_batch_size], padding=True, padding_side="left").to(self.model.device, non_blocking=True)
+ tokd_input_batches.append(batch)
+ torch.cuda.synchronize()
+
+ outputs = []
  with torch.no_grad():
- self.logits_processor.reset(n_gram)
- output = self.model.generate(
- **tokd_input,
- max_new_tokens=max_new_tokens,
- do_sample=do_sample,
- logits_processor=logits_processor,
- pad_token_id=self.tokenizer.eos_token_id,
- tokenizer=self.tokenizer,
- **kwargs
- )
- output = output[:,tokd_input["input_ids"].shape[-1]:].cpu()
+ bar = tqdm(total=len(tokd_inputs), desc="Generating text", disable=not use_tqdm)
+ for tokd_input_batch in tokd_input_batches:
+ self.logits_processor.reset(n_gram)
+ output = self.model.generate(
+ **tokd_input_batch,
+ logits_processor=logits_processor,
+ pad_token_id=self.tokenizer.eos_token_id,
+ tokenizer=self.tokenizer,
+ **kwargs
+ )
+ output = output[:,tokd_input_batch["input_ids"].shape[-1]:].to("cpu", non_blocking=True)
+ outputs.append(output)
+ bar.update(tokd_input_batch["input_ids"].shape[0])
+ torch.cuda.synchronize()
+ outputs = [j for i in outputs for j in i] # Flatten the list of outputs
+
+ # Restore original ordering
+ if max_batch_size > 1:
+ reordered_outputs = [None] * len(outputs)
+ num_return_sequences = len(outputs) // len(tokd_inputs)
+ for i, idx in enumerate(tokd_inputs_order):
+ reordered_outputs[idx * num_return_sequences:(idx + 1) * num_return_sequences] = outputs[i * num_return_sequences:(i + 1) * num_return_sequences]
+ outputs = reordered_outputs

  return_dict = {}

  if return_scores:
- cumulative_token_count = self.get_cumulative_token_count(self.id, output, n_gram = n_gram, return_dense=False)
+ cumulative_token_count = self.get_cumulative_token_count(self.id, outputs, n_gram = n_gram, return_dense=False)
  cumulative_token_count = vstack([i[0] for i in cumulative_token_count], format="csr")
  q_score, _, _ = self.watermarking_fn.q(cumulative_token_count, k_p = [self.k_p], use_tqdm=False)
- return_dict["q_score"] = q_score[:,0]
+ return_dict["q_score"] = q_score

  if return_tokens:
- return_dict["tokens"] = output
+ return_dict["tokens"] = outputs

  if return_text:
- decoded_output = self.tokenizer.batch_decode(output, skip_special_tokens=True)
+ decoded_output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
  decoded_output = [i.strip() for i in decoded_output]
  return_dict["text"] = decoded_output

- if len(output) == 1:
+ if is_single:
  for k, v in return_dict.items():
  return_dict[k] = v[0]

- if return_text and len(return_dict) == 0:
+ if return_text and len(return_dict) == 1:
  return decoded_output

  return return_dict
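The hunks above rework the Watermarker constructor (it now accepts a model name string and loads the tokenizer and model itself) and make generate() operate on batches of prompts. Below is a minimal, hedged usage sketch written against the signatures visible in this diff only; the import path is inferred from the file layout in RECORD, and the model name, id, and prompts are illustrative assumptions rather than documented defaults.

# Sketch based on the 0.2.0 signatures shown above; names marked below are assumptions.
from waterfall.WatermarkerBase import Watermarker  # path inferred from RECORD

# 0.2.0 accepts a model name string; 0.1.7 required a pre-built tokenizer/model pair.
watermarker = Watermarker(
    model="meta-llama/Llama-3.1-8B-Instruct",  # resolved via AutoModelForCausalLM
    id=42,          # illustrative watermark key
    kappa=6,        # watermarking strength (constructor default shown in the diff)
    k_p=1,
    device="auto",  # forwarded to device_map in load_model()
)

# generate() now takes a list of prompts, batches them, and returns a flat list of texts.
texts = watermarker.generate(
    prompts=["Paraphrase this sentence.", "And this one."],  # illustrative prompts
    return_text=True,
    use_tqdm=True,
    max_new_tokens=128,  # passed through **kwargs to model.generate
)
print(texts)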
waterfall/permute.py CHANGED
@@ -49,16 +49,23 @@ class Permute:
  size_per_permutation_in_bytes = N * self.dtype.itemsize
  cache_size = int(psutil.virtual_memory().total * 0.02 / size_per_permutation_in_bytes) # 2% of total memory
  self.permutations.capacity = cache_size
+ self.no_permutation = np.arange(self.N, dtype=self.dtype)
+
+ def _permute(self, key):
+ return np.random.RandomState(key).permutation(self.N).astype(self.dtype)

  def get_permutation(self, prev_tok, id : int, cache : bool = False) -> np.ndarray:
+ # Skip special tokens
+ if any((i >= self.N for i in prev_tok)):
+ return self.no_permutation
  key = (id, *prev_tok)
  if cache:
  permutation = self.permutations.get(key)
  if permutation is None:
- permutation = np.random.RandomState(key).permutation(self.N).astype(self.dtype)
+ permutation = self._permute(key)
  self.permutations.put(key, permutation)
  else:
- permutation = np.random.RandomState(key).permutation(self.N).astype(self.dtype)
+ permutation = self._permute(key)
  return permutation

  def get_unshuffled_indices(self, ids, args) -> dict[int, np.ndarray]:
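The permute.py change factors the seeded shuffle into _permute and adds an identity fallback for token ids outside the vocabulary (special tokens). A small sketch of the seeding idea as it appears in the diff: the permutation is fully determined by the (id, *prev_tok) key, so the same watermark id and preceding token always reproduce the same shuffle. N and the key here are illustrative values, not defaults from the package.

import numpy as np

N = 10          # vocab size stand-in; Permute uses the tokenizer's vocab_size
key = (42, 7)   # (watermark id, previous token id), illustrative

# Same seeding call as Permute._permute: deterministic given the key.
perm_a = np.random.RandomState(key).permutation(N)
perm_b = np.random.RandomState(key).permutation(N)
assert np.array_equal(perm_a, perm_b)

# 0.2.0 additionally returns an identity permutation (np.arange(N)) when any
# previous token id is >= N, i.e. when the context token is a special token.
no_permutation = np.arange(N)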
waterfall/watermark.py CHANGED
@@ -8,8 +8,8 @@ from typing import List, Literal, Optional, Tuple

  from transformers import AutoTokenizer, AutoModelForCausalLM
  from transformers.modeling_utils import PreTrainedModel
+ from transformers.generation.configuration_utils import GenerationConfig
  from sentence_transformers import SentenceTransformer
- from tqdm.auto import tqdm

  from waterfall.WatermarkingFnFourier import WatermarkingFnFourier
  from waterfall.WatermarkingFnSquare import WatermarkingFnSquare
@@ -24,7 +24,7 @@ PROMPT = (
  )
  PRE_PARAPHRASED = "Here is a paraphrased version of the text while preserving the semantic similarity:\n\n"

- waterfall_cached_watermarking_model = None # Global variable to cache the watermarking model
+ waterfall_cached_watermarking_model: PreTrainedModel | None = None # Global variable to cache the watermarking model

  def detect_gpu() -> str:
  """
@@ -41,47 +41,137 @@ def detect_gpu() -> str:
  else:
  return 'cpu'

- def watermark(
- T_o: str,
- watermarker: Watermarker,
- sts_model: SentenceTransformer,
+ def watermark_texts(
+ T_os: List[str],
+ id: Optional[int] = None,
+ k_p: int = 1,
+ kappa: float = 2.0,
+ model_path: Optional[str] = "meta-llama/Llama-3.1-8B-Instruct",
+ sts_model_path: Optional[str] = "sentence-transformers/all-mpnet-base-v2",
+ watermark_fn: Literal["fourier", "square"] = "fourier",
+ watermarker: Optional[Watermarker] = None,
+ sts_model: Optional[SentenceTransformer] = None,
+ device: str = detect_gpu(),
+ STS_scale: float = 2.0,
+ use_tqdm: bool = False,
+ do_sample: bool = False,
+ temperature: Optional[float] = None,
+ top_p: Optional[float] = None,
+ max_new_tokens: Optional[int] = None,
  num_beam_groups: int = 4,
  beams_per_group: int = 2,
- STS_scale: float = 2.0,
  diversity_penalty: float = 0.5,
- max_new_tokens: Optional[int] = None,
- **kwargs
- ) -> str:
- paraphrasing_prompt = watermarker.tokenizer.apply_chat_template(
- [
- {"role":"system", "content":PROMPT},
- {"role":"user", "content":T_o},
- ], tokenize=False, add_generation_prompt = True) + PRE_PARAPHRASED
+ stop_at_double_newline: bool = True, # if True, will stop generation at the first double newline. Prevent repeated paraphrasing of the same text.
+ ) -> List[str]:
+ if watermark_fn == 'fourier':
+ watermarkingFnClass = WatermarkingFnFourier
+ elif watermark_fn == 'square':
+ watermarkingFnClass = WatermarkingFnSquare
+ else:
+ raise ValueError("Invalid watermarking function")
+
+ # Check if watermarker/model/tokenizer are loaded
+ if watermarker is None:
+ assert model_path is not None, "model_path must be provided if watermarker is not passed"
+ assert id is not None, "id must be provided if watermarker is not passed"
+ global waterfall_cached_watermarking_model
+
+ if isinstance(waterfall_cached_watermarking_model, PreTrainedModel) and waterfall_cached_watermarking_model.name_or_path != model_path:
+ device = waterfall_cached_watermarking_model.device.type
+ waterfall_cached_watermarking_model = None
+ gc.collect()
+ if device == "cuda":
+ torch.cuda.empty_cache()
+ elif device == "mps":
+ torch.mps.empty_cache()
+
+ if waterfall_cached_watermarking_model is None:
+ model = model_path
+ else:
+ model = waterfall_cached_watermarking_model
+
+ watermarker = Watermarker(model=model, id=id, kappa=kappa, k_p=k_p, watermarkingFnClass=watermarkingFnClass)
+ else:
+ device = watermarker.model.device.type
+ id = watermarker.id
+ waterfall_cached_watermarking_model = watermarker.model
+
+ # Check if sts model is loaded
+ if sts_model is None:
+ assert sts_model_path is not None, "sts_model_path must be provided if sts_model is not passed"
+ sts_model = SentenceTransformer(sts_model_path, device=device)
+
+ # Replace all \n\n in source text if stop_at_double_newline is True
+ # Models tend to generate \n\n before endlessly repeating itself, so we want to stop the model from doing that
+ if stop_at_double_newline:
+ for i in range(len(T_os)):
+ if "\n\n" in T_os[i]:
+ logging.warning(f"Text idx {i} contains \\n\\n and stop_at_double_newline is set to True, replacing all \\n\\n in text.")
+ T_os[i] = T_os[i].replace("\n\n", " ") # replace double newlines with space
+
+ # Add system prompt and prefill, and format into appropriate chat format
+ formatted_T_os = watermarker.format_prompt(
+ T_os,
+ system_prompt=PROMPT,
+ assistant_prefill=PRE_PARAPHRASED,
+ )
+
+ if max_new_tokens is None:
+ max_input_len = max(len(p) for p in formatted_T_os)
+ max_new_tokens = max_input_len
+
+ if do_sample:
+ assert (do_sample and temperature is not None and top_p is not None and num_beam_groups == 1 and beams_per_group == 1), \
+ "do_sample=True requires temperature, top_p, num_beam_groups=1 and beams_per_group=1"
+ else: # Using beam search
+ assert (not do_sample and temperature is None and top_p is None and num_beam_groups >= 1 and beams_per_group >= 1), \
+ "do_sample=False requires temperature=None, top_p=None, num_beam_groups>=1 and beams_per_group>=1"
+
+ eos_token_id = watermarker.tokenizer.eos_token_id
+ # add "\n\n" tokens to eos_token_id list
+ if stop_at_double_newline:
+ eos_token_id = [eos_token_id]
+ # llama tokenizer's .vocab() has weird symbols and doesn't work with GenerationConfig's stop_strings, so we have to brute force check all tokens
+ for token_id,string in enumerate(watermarker.tokenizer.batch_decode(torch.arange(watermarker.tokenizer.vocab_size).unsqueeze(1))):
+ if "\n\n" in string:
+ eos_token_id.append(token_id)
+
+ generation_config = GenerationConfig(
+ max_new_tokens=max_new_tokens,
+ do_sample=do_sample,
+ temperature=temperature,
+ top_p=top_p,
+ num_beam_groups=num_beam_groups,
+ num_beams=num_beam_groups * beams_per_group,
+ diversity_penalty=diversity_penalty,
+ eos_token_id=eos_token_id,
+ num_return_sequences=num_beam_groups * beams_per_group,
+ )

  watermarked = watermarker.generate(
- paraphrasing_prompt,
- return_scores = True,
- max_new_tokens = int(len(paraphrasing_prompt) * 1.5) if max_new_tokens is None else max_new_tokens,
- do_sample = False, temperature=None, top_p=None,
- num_beams = num_beam_groups * beams_per_group,
- num_beam_groups = num_beam_groups,
- num_return_sequences = num_beam_groups * beams_per_group,
- diversity_penalty = diversity_penalty,
- **kwargs,
- )
+ prompts=formatted_T_os,
+ return_text=True,
+ return_scores=True,
+ use_tqdm=use_tqdm,
+ generation_config=generation_config,
+ )
+ T_ws = watermarked["text"]
+ # Reshape T_ws to Queries X Beams
+ num_beams = num_beam_groups * beams_per_group
+ T_ws = [T_ws[i * num_beams:(i + 1) * num_beams] for i in range(len(T_os))]

  # Select best paraphrasing based on q_score and semantic similarity
- sts_scores = STS_scorer(T_o, watermarked["text"], sts_model)
- selection_score = sts_scores * STS_scale + torch.from_numpy(watermarked["q_score"])
- selection = torch.argmax(selection_score)
+ sts_scores = STS_scorer_batch(T_os, T_ws, sts_model)
+ selection_scores = sts_scores * STS_scale + torch.from_numpy(watermarked["q_score"]).reshape(-1, num_beams)
+ selections = torch.argmax(selection_scores, dim = -1)

- T_w = watermarked["text"][selection]
+ T_ws = [T_w[selection] for T_w, selection in zip(T_ws, selections)]

- return T_w
+ return T_ws

- def verify_texts(texts: List[str], id: int,
- watermarker: Optional[Watermarker] = None,
- k_p: Optional[int] = None,
+ def verify_texts(texts: List[str], id: int,
+ watermarker: Optional[Watermarker] = None,
+ k_p: Optional[int] = None,
  model_path: Optional[str] = "meta-llama/Llama-3.1-8B-Instruct",
  return_extracted_k_p: bool = False
  ) -> np.ndarray | Tuple[np.ndarray,np.ndarray]:
@@ -89,9 +179,8 @@ def verify_texts(texts: List[str], id: int,

  if watermarker is None:
  assert model_path is not None, "model_path must be provided if watermarker is not passed"
- tokenizer = AutoTokenizer.from_pretrained(model_path)
- watermarker = Watermarker(tokenizer=tokenizer)
-
+ watermarker = Watermarker(tokenizer=model_path)
+
  if k_p is None:
  k_p = watermarker.k_p

@@ -135,87 +224,6 @@ def STS_scorer(
  cos_sim = cos_sim.item()
  return cos_sim

- def watermark_texts(
- T_os: List[str],
- id: Optional[int] = None,
- k_p: int = 1,
- kappa: float = 2.0,
- model_path: str = "meta-llama/Llama-3.1-8B-Instruct",
- torch_dtype: torch.dtype = torch.bfloat16,
- sts_model_path: str = "sentence-transformers/all-mpnet-base-v2",
- watermark_fn: Literal["fourier", "square"] = "fourier",
- watermarker: Optional[Watermarker] = None,
- sts_model: Optional[SentenceTransformer] = None,
- device: str = detect_gpu(),
- num_beam_groups: int = 4,
- beams_per_group: int = 2,
- diversity_penalty: float = 0.5,
- STS_scale:float = 2.0,
- use_tqdm: bool = False,
- stop_at_double_newline: bool = True, # if True, will stop generation at the first double newline. Prevent repeated paraphrasing of the same text.
- ) -> List[str]:
- if watermark_fn == 'fourier':
- watermarkingFnClass = WatermarkingFnFourier
- elif watermark_fn == 'square':
- watermarkingFnClass = WatermarkingFnSquare
- else:
- raise ValueError("Invalid watermarking function")
-
- if watermarker is None:
- assert model_path is not None, "model_path must be provided if watermarker is not passed"
- global waterfall_cached_watermarking_model
-
- if isinstance(waterfall_cached_watermarking_model, PreTrainedModel) and waterfall_cached_watermarking_model.name_or_path != model_path:
- device = waterfall_cached_watermarking_model.device.type
- waterfall_cached_watermarking_model = None
- gc.collect()
- if device == "cuda":
- torch.cuda.empty_cache()
- elif device == "mps":
- torch.mps.empty_cache()
-
- if waterfall_cached_watermarking_model is None:
- waterfall_cached_watermarking_model = AutoModelForCausalLM.from_pretrained(
- model_path,
- torch_dtype=torch_dtype,
- device_map=device,
- )
- model = waterfall_cached_watermarking_model
- tokenizer = AutoTokenizer.from_pretrained(model_path)
-
- watermarker = Watermarker(tokenizer=tokenizer, model=model, id=id, kappa=kappa, k_p=k_p, watermarkingFnClass=watermarkingFnClass)
- else:
- tokenizer = watermarker.tokenizer
- device = watermarker.model.device
- id = watermarker.id
-
- if id is None:
- raise Exception("ID or Watermarker class must be passed to watermark_texts.")
-
- if sts_model is None:
- assert sts_model_path is not None, "sts_model_path must be provided if sts_model is not passed"
- sts_model = SentenceTransformer(sts_model_path, device=device)
-
- T_ws = []
-
- for T_o in tqdm(T_os, desc="Watermarking texts", disable=not use_tqdm):
- if stop_at_double_newline and "\n\n" in T_o:
- logging.warning("Text contains \\n\\n and stop_at_double_newline is set to True, replacing all \\n\\n in text.")
- T_o = T_o.replace("\n\n", " ") # replace double newlines with space
- T_w = watermark(
- T_o,
- watermarker = watermarker,
- sts_model = sts_model,
- num_beam_groups = num_beam_groups,
- beams_per_group = beams_per_group,
- diversity_penalty = diversity_penalty,
- STS_scale = STS_scale,
- stop_strings=["\n\n"] if stop_at_double_newline else None,
- )
- T_ws.append(T_w)
-
- return T_ws
-
  def pretty_print(
  T_o: str, T_w: str,
  sts_score: float,
@@ -303,13 +311,15 @@ def main():
  sts_model = SentenceTransformer(sts_model_name, device=device)

  T_ws = watermark_texts(
- T_os, id, k_p, kappa,
- watermarker=watermarker, sts_model=sts_model,
+ T_os,
+ id=id, k_p=k_p, kappa=kappa,
+ watermarker=watermarker,
+ sts_model=sts_model,
  beams_per_group=beams_per_group,
  num_beam_groups=num_beam_groups,
  diversity_penalty=diversity_penalty,
  STS_scale=STS_scale,
- use_tqdm=True
+ use_tqdm=True,
  )

  # watermarker = Watermarker(tokenizer=tokenizer, model=None, id=id, k_p=k_p, watermarkingFnClass=watermarkingFnClass) # If only verifying the watermark, do not need to instantiate the model
@@ -320,7 +330,7 @@ def main():
  # in an IDE or something else without terminal size
  try:
  column_size = os.get_terminal_size().columns
- except OSError as ose:
+ except OSError:
  column_size = 80

  print("=" * column_size)
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: waterfall
- Version: 0.1.7
+ Version: 0.2.0
  Summary: Scalable Framework for Robust Text Watermarking and Provenance for LLMs
  Project-URL: Homepage, https://github.com/aoi3142/Waterfall
  Project-URL: Issues, https://github.com/aoi3142/Waterfall/issues
@@ -15,7 +15,7 @@ Requires-Dist: numpy>=2.0.0
  Requires-Dist: scipy>=1.13.0
  Requires-Dist: sentence-transformers>=3.0.0
  Requires-Dist: torch>=2.3.0
- Requires-Dist: transformers>=4.43.1
+ Requires-Dist: transformers<4.55.0,>=4.43.1
  Description-Content-Type: text/markdown

  # Waterfall: Scalable Framework for Robust Text Watermarking and Provenance for LLMs [EMNLP 2024 Main Long]
@@ -0,0 +1,12 @@
+ waterfall/WatermarkerBase.py,sha256=Jg9jnSU_JYfyCJAyAUqSBcE_IoAByneLC1yKfPEv4mo,20774
+ waterfall/WatermarkingFn.py,sha256=-b-kGRdL0a7eKRqJmcHPAR_rCjxQYnsg1Ne6bTwBc1I,1931
+ waterfall/WatermarkingFnFourier.py,sha256=QYayAQYwi1dQkDIyqmvhU568VhrVYTVy47HkI8F8SZs,1358
+ waterfall/WatermarkingFnSquare.py,sha256=2PAO05DdKT02npo7GDf_82D520nP7kGAWK6H4E4JMt4,1638
+ waterfall/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ waterfall/permute.py,sha256=uYKdmn4pGvjB6hInInLGxFIF6vt507lqJ_qe-ST1PFE,2783
+ waterfall/watermark.py,sha256=_foE_9K1xBdJhTZq2EOeuanl8X04A3wh_F02_1m0LMA,14468
+ waterfall-0.2.0.dist-info/METADATA,sha256=AFjq_Ox5oTq2V2a7q80CEAGlYtjpYfOtvEQTSp98w-A,8722
+ waterfall-0.2.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ waterfall-0.2.0.dist-info/entry_points.txt,sha256=XXnUzuWXu2nc9j4WAll9tq6HyodN_8WJLjeG0O4Y2Gw,60
+ waterfall-0.2.0.dist-info/licenses/LICENSE,sha256=zAtaO-k41Q-Q4Etl4bzuh7pgNJsPH-dYfzvznRa0OvM,11341
+ waterfall-0.2.0.dist-info/RECORD,,
@@ -1,12 +0,0 @@
- waterfall/WatermarkerBase.py,sha256=NrDo4yJ4gnliTHH3LZemALpU_L-MCaPapevV1YnRHuE,12999
- waterfall/WatermarkingFn.py,sha256=-b-kGRdL0a7eKRqJmcHPAR_rCjxQYnsg1Ne6bTwBc1I,1931
- waterfall/WatermarkingFnFourier.py,sha256=QYayAQYwi1dQkDIyqmvhU568VhrVYTVy47HkI8F8SZs,1358
- waterfall/WatermarkingFnSquare.py,sha256=2PAO05DdKT02npo7GDf_82D520nP7kGAWK6H4E4JMt4,1638
- waterfall/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- waterfall/permute.py,sha256=RwxOHFhx_VSOhhFwy5s79YgwTUBkfW2-LCCXYR3VT2o,2582
- waterfall/watermark.py,sha256=W5jYGqYGOXXO-KLPKzJoin5zC_Xb6Xk9BzsAA9-LKXA,13494
- waterfall-0.1.7.dist-info/METADATA,sha256=-QVkPeyZWXdPHr_SvhvIFyCNy3G2GuzHKPmg9w8Z1-I,8714
- waterfall-0.1.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- waterfall-0.1.7.dist-info/entry_points.txt,sha256=XXnUzuWXu2nc9j4WAll9tq6HyodN_8WJLjeG0O4Y2Gw,60
- waterfall-0.1.7.dist-info/licenses/LICENSE,sha256=zAtaO-k41Q-Q4Etl4bzuh7pgNJsPH-dYfzvznRa0OvM,11341
- waterfall-0.1.7.dist-info/RECORD,,