waterfall 0.1.7__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff shows the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in the public registry.
waterfall/WatermarkerBase.py CHANGED
@@ -6,19 +6,24 @@ from collections import defaultdict
  from functools import partial
  from multiprocessing import Pool
  from typing import List, Tuple, Optional
+ from itertools import repeat

  import numpy as np
  import torch
  from scipy.sparse import csr_matrix, vstack
  from tqdm import tqdm
+ from transformers import AutoTokenizer, AutoModelForCausalLM
  from transformers.modeling_utils import PreTrainedModel
- from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+ from transformers.tokenization_utils_base import PreTrainedTokenizerBase, BatchEncoding
  from transformers.generation.logits_process import LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper
+ from transformers.generation.configuration_utils import GenerationConfig

  from waterfall.permute import Permute
  from waterfall.WatermarkingFn import WatermarkingFn
  from waterfall.WatermarkingFnFourier import WatermarkingFnFourier

+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
  class PerturbationProcessor(LogitsProcessor):
  def __init__(self,
  N : int = 32000, # Vocab size
@@ -28,7 +33,7 @@ class PerturbationProcessor(LogitsProcessor):
  self.id = id
  self.N = N
  self.init_token_count = None
- self.phi = np.ones(N)
+ self.phi = torch.zeros(N)
  self.n_gram = 2

  self.skip_watermark = False
@@ -38,14 +43,14 @@ class PerturbationProcessor(LogitsProcessor):
  def reset(self, n_gram : int = 2) -> None:
  self.n_gram = n_gram
  self.init_token_count = None
- if np.allclose(self.phi,np.median(self.phi)):
+ if torch.allclose(self.phi,torch.median(self.phi)):
  self.skip_watermark = True
  logging.warning(f"Generating without watermark as watermarking function is flat")
  else:
  self.skip_watermark = False

  def set_phi(self, phi : np.ndarray) -> None:
- self.phi = phi
+ self.phi = torch.from_numpy(phi)

  def __call__(self, input_ids: torch.LongTensor,
  scores: torch.FloatTensor) -> torch.FloatTensor:
@@ -60,12 +65,17 @@ class PerturbationProcessor(LogitsProcessor):
  if self.init_token_count + self.n_gram - 1 > input_ids.shape[1]:
  return scores

+ # using numpy as PyTorch tensors doesn't hash properly for rng and dict key
  prev_tokens = input_ids[:,-self.n_gram+1:].cpu().numpy()
- permutations = [self.permute.get_permutation(prev_tokens[i,:], self.id, cache=True) for i in range(prev_tokens.shape[0])]

- scores[:,:self.N] += torch.tensor(self.phi[permutations],
- device=scores.device,
- dtype=scores.dtype)
+ permutations = (
+ self.permute.get_permutation(prev_tokens[i,:], self.id, cache=True)
+ for i in range(prev_tokens.shape[0])
+ )
+ perturbations = torch.stack([
+ self.phi[permutation] for permutation in permutations
+ ])
+ scores[:,:self.N] += perturbations.to(device=scores.device, dtype=scores.dtype)
  return scores

  def indices_to_counts(N : int, dtype : np.dtype, indices : np.ndarray) -> csr_matrix:
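Editor's note: the rewritten block above gathers one perturbation per vocabulary slot by fancy-indexing the 1-D phi with a permutation, now done directly on the torch tensor instead of converting a numpy result. A minimal illustrative sketch (not package code; N, phi, and the permutation are placeholder values) showing that the two indexing paths agree:

    import numpy as np
    import torch

    N = 8
    phi_np = np.arange(N, dtype=np.float32)          # stand-in watermarking values
    phi_t = torch.from_numpy(phi_np)                 # 0.2.1 keeps phi as a torch tensor
    perm = np.random.RandomState(0).permutation(N)   # stand-in vocabulary permutation

    # Fancy indexing gathers one perturbation per vocabulary slot in both cases.
    assert np.allclose(phi_np[perm], phi_t[perm].numpy())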
@@ -74,99 +84,258 @@ def indices_to_counts(N : int, dtype : np.dtype, indices : np.ndarray) -> csr_ma

  class Watermarker:
  def __init__(self,
- tokenizer : PreTrainedTokenizerBase,
- model : Optional[PreTrainedModel] = None,
+ tokenizer : Optional[PreTrainedTokenizerBase | str] = None,
+ model : Optional[PreTrainedModel | str] = None,
  id : int = 0,
  kappa : float = 6,
  k_p : int = 1,
  n_gram : int = 2,
- watermarkingFnClass = WatermarkingFnFourier
+ watermarkingFnClass = WatermarkingFnFourier,
+ device = None,
  ) -> None:
  assert kappa >= 0, f"kappa must be >= 0, value provided is {kappa}"

- assert (model is None) or isinstance(model, PreTrainedModel), f"model must be a transformers model, value provided is {type(model)}" # argument order for tokenizer and model were swapped since the original code
-
- self.tokenizer = tokenizer
- self.model = model
  self.id = id
  self.k_p = k_p
  self.n_gram = n_gram
  self.kappa = kappa

+ if tokenizer is None:
+ if isinstance(model, str):
+ self.tokenizer = AutoTokenizer.from_pretrained(model)
+ elif isinstance(model, PreTrainedModel):
+ self.tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
+ else:
+ raise NotImplementedError("tokenizer must be provided or model must be a string or PreTrainedModel")
+ elif isinstance(tokenizer, str):
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+ else:
+ self.tokenizer = tokenizer
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+ self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
+
  self.N = self.tokenizer.vocab_size
- self.logits_processor = PerturbationProcessor(N = self.N, id = id)
+
+ self.logits_processor = PerturbationProcessor(N = self.N, id = self.id)
+
+ if isinstance(model, str):
+ self.load_model(model, device_map=device)
+ else:
+ self.model = model
+
+ assert (self.model is None) or isinstance(self.model, PreTrainedModel), f"model must be a transformers model, value provided is {type(self.model)}" # argument order for tokenizer and model were swapped since the original code

  self.compute_phi(watermarkingFnClass)

+ def load_model(self, model_name_or_path : str, device_map : str = "auto"):
+ self.model = AutoModelForCausalLM.from_pretrained(
+ model_name_or_path,
+ device_map=device_map,
+ )
+
  def compute_phi(self, watermarkingFnClass = WatermarkingFnFourier) -> None:
- self.watermarking_fn: WatermarkingFn = watermarkingFnClass(id = id, k_p = self.k_p, N = self.N, kappa = self.kappa)
+ self.watermarking_fn: WatermarkingFn = watermarkingFnClass(id = self.id, k_p = self.k_p, N = self.N, kappa = self.kappa)
  self.phi = self.watermarking_fn.phi

  self.logits_processor.set_phi(self.phi)

+ # Format prompt(s) into chat template
+ def format_prompt(
+ self,
+ T_os : str | List[str],
+ system_prompt : Optional[str] = None,
+ assistant_prefill : Optional[str | List[str]] = "",
+ ) -> str | List[str]:
+ if isinstance(system_prompt, str):
+ _system_prompt = {"role":"system", "content":system_prompt}
+ is_single = isinstance(T_os, str)
+ if is_single:
+ T_os = [T_os]
+ if not isinstance(assistant_prefill, list):
+ assistant_prefill = repeat(assistant_prefill, len(T_os))
+ else:
+ assert len(assistant_prefill) == len(T_os), "Length of assistant_prefill must match length of T_os"
+ formatted_prompts = []
+ for T_o, prefill in zip(T_os, assistant_prefill):
+ formatted_prompt : str = self.tokenizer.apply_chat_template(
+ [
+ _system_prompt,
+ {"role":"user", "content":T_o},
+ ], tokenize=False, add_generation_prompt = True)
+ if prefill is not None:
+ formatted_prompt += prefill
+ formatted_prompts.append(formatted_prompt)
+ if is_single:
+ return formatted_prompts[0]
+ return formatted_prompts
+
+ # Find the largest batch size that fits in GPU memory
+ def find_largest_batch_size(
+ self,
+ tokd_inputs : List[BatchEncoding],
+ logits_processor : List[LogitsProcessor] = [],
+ **kwargs,
+ ):
+ longest_idx = np.argmax([tokd_input["input_ids"].shape[-1] for tokd_input in tokd_inputs])
+ if "generation_config" in kwargs:
+ generation_config = GenerationConfig(**kwargs["generation_config"].to_dict()) # copy
+ max_new_tokens = generation_config.max_new_tokens
+ else:
+ generation_config = GenerationConfig(**kwargs)
+ max_new_tokens = kwargs.get("max_new_tokens", 2048)
+ generation_config.update(max_new_tokens=1)
+ input_ids = tokd_inputs[longest_idx]["input_ids"]
+ input_ids = torch.zeros(
+ (1, max_new_tokens + input_ids.shape[-1] - 1),
+ dtype=input_ids.dtype,
+ device=self.model.device
+ )
+ max_batch_size = 1
+ with torch.no_grad():
+ while max_batch_size < min(16, len(tokd_inputs)):
+ torch.cuda.empty_cache()
+ try:
+ _ = self.model.generate(
+ input_ids=input_ids,
+ attention_mask=torch.ones_like(input_ids),
+ logits_processor=logits_processor,
+ generation_config=generation_config,
+ pad_token_id=self.tokenizer.eos_token_id,
+ tokenizer=self.tokenizer,
+ )
+ max_batch_size = input_ids.shape[0]
+ except RuntimeError as e:
+ if "CUDA out of memory" in str(e):
+ break
+ else:
+ raise e
+ input_ids = torch.cat([input_ids, input_ids], dim=0)
+ torch.cuda.empty_cache()
+ return max_batch_size
+
  def generate(
  self,
- prompt : Optional[str] = None,
- tokd_input : Optional[torch.Tensor] = None,
+ prompts : Optional[str | List[str]] = None,
+ tokd_inputs : Optional[torch.Tensor | List[torch.Tensor] | BatchEncoding | List[BatchEncoding]] = None,
  n_gram : Optional[int] = None,
- max_new_tokens : int = 1000,
- return_text : bool =True,
- return_tokens : bool =False,
- return_scores : bool =False,
- do_sample : bool =True,
- **kwargs
- ) -> List[str] | dict:
+ return_text : bool = True,
+ return_tokens : bool = False,
+ return_scores : bool = False,
+ use_tqdm : bool = False,
+ batched_generate : bool = True,
+ **kwargs # Other generate parameters
+ ) -> List[str] | dict: # Returns flattened list of query x beam

  assert self.model is not None, "Model is not loaded. Please load the model before generating text."

+ is_single = isinstance(prompts, str) or isinstance(tokd_inputs, torch.Tensor)
+ if is_single:
+ prompts = [prompts] if prompts is not None else None
+ tokd_inputs = [tokd_inputs] if tokd_inputs is not None else None
+
  if n_gram is None:
  n_gram = self.n_gram
- if tokd_input is None:
- assert prompt is not None, "Either prompt or tokd_input must be provided."
- tokd_input = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
- tokd_input = tokd_input.to(self.model.device)
+ if tokd_inputs is None:
+ assert prompts is not None, "Either prompt or tokd_input must be provided."
+ tokd_inputs = [self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False) for prompt in prompts]
+
+ # If tokd_input is a tensor, convert it to a BatchEncoding
+ squeezed_tokd_inputs = []
+ for tokd_input in tokd_inputs:
+ if isinstance(tokd_input, torch.Tensor):
+ input_ids = tokd_input
+ attention_mask = torch.ones_like(tokd_input)
+ else:
+ input_ids = tokd_input["input_ids"]
+ attention_mask = tokd_input["attention_mask"]
+ if input_ids.ndim == 2:
+ input_ids = input_ids.squeeze()
+ attention_mask = attention_mask.squeeze()
+ squeezed_tokd_inputs.append(BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask}))
+ tokd_inputs = squeezed_tokd_inputs
+
  logits_processor = []
- if "top_k" in kwargs and kwargs["top_k"] is not None and kwargs["top_k"] != 0:
- logits_processor.append(TopKLogitsWarper(kwargs.pop("top_k")))
- if "top_p" in kwargs and kwargs["top_p"] is not None and kwargs["top_p"] < 1.0:
- logits_processor.append(TopPLogitsWarper(kwargs.pop("top_p")))
+ # Ensure top_k and top_p happens before watermarking
+ if "generation_config" in kwargs:
+ generation_config: GenerationConfig = kwargs["generation_config"]
+ top_k = generation_config.top_k
+ top_p = generation_config.top_p
+ generation_config.update(top_p=1.0)
+ else:
+ top_k = kwargs.pop("top_k", None)
+ top_p = kwargs.pop("top_p", None)
+
+ if top_k is not None and top_k != 0:
+ logits_processor.append(TopKLogitsWarper(top_k))
+ if top_p is not None and top_p < 1.0:
+ logits_processor.append(TopPLogitsWarper(top_p))
  if self.kappa != 0:
  logits_processor.append(self.logits_processor)

+ if batched_generate and len(tokd_inputs) >= 8:
+ max_batch_size = self.find_largest_batch_size(tokd_inputs, logits_processor=logits_processor, **kwargs)
+ else:
+ max_batch_size = 1
+
+ # Group inputs by token length
+ if max_batch_size > 1:
+ tokd_inputs_order = sorted(range(len(tokd_inputs)), key=lambda i: tokd_inputs[i]["input_ids"].shape[-1])
+ tokd_inputs = [tokd_inputs[i] for i in tokd_inputs_order]
+ else:
+ tokd_inputs_order = range(len(tokd_inputs))
+ tokd_input_batches = []
+ for i in range(0, len(tokd_inputs), max_batch_size):
+ batch = self.tokenizer.pad(tokd_inputs[i:i+max_batch_size], padding=True, padding_side="left").to(self.model.device, non_blocking=True)
+ tokd_input_batches.append(batch)
+ torch.cuda.synchronize()
+
+ outputs = []
  with torch.no_grad():
- self.logits_processor.reset(n_gram)
- output = self.model.generate(
- **tokd_input,
- max_new_tokens=max_new_tokens,
- do_sample=do_sample,
- logits_processor=logits_processor,
- pad_token_id=self.tokenizer.eos_token_id,
- tokenizer=self.tokenizer,
- **kwargs
- )
- output = output[:,tokd_input["input_ids"].shape[-1]:].cpu()
+ bar = tqdm(total=len(tokd_inputs), desc="Generating text", disable=not use_tqdm)
+ for tokd_input_batch in tokd_input_batches:
+ self.logits_processor.reset(n_gram)
+ output = self.model.generate(
+ **tokd_input_batch,
+ logits_processor=logits_processor,
+ pad_token_id=self.tokenizer.eos_token_id,
+ tokenizer=self.tokenizer,
+ **kwargs
+ )
+ output = output[:,tokd_input_batch["input_ids"].shape[-1]:].to("cpu", non_blocking=True)
+ outputs.append(output)
+ bar.update(tokd_input_batch["input_ids"].shape[0])
+ torch.cuda.synchronize()
+ outputs = [j for i in outputs for j in i] # Flatten the list of outputs
+
+ # Restore original ordering
+ if max_batch_size > 1:
+ reordered_outputs = [None] * len(outputs)
+ num_return_sequences = len(outputs) // len(tokd_inputs)
+ for i, idx in enumerate(tokd_inputs_order):
+ reordered_outputs[idx * num_return_sequences:(idx + 1) * num_return_sequences] = outputs[i * num_return_sequences:(i + 1) * num_return_sequences]
+ outputs = reordered_outputs

  return_dict = {}

  if return_scores:
- cumulative_token_count = self.get_cumulative_token_count(self.id, output, n_gram = n_gram, return_dense=False)
+ cumulative_token_count = self.get_cumulative_token_count(self.id, outputs, n_gram = n_gram, return_dense=False)
  cumulative_token_count = vstack([i[0] for i in cumulative_token_count], format="csr")
  q_score, _, _ = self.watermarking_fn.q(cumulative_token_count, k_p = [self.k_p], use_tqdm=False)
- return_dict["q_score"] = q_score[:,0]
+ return_dict["q_score"] = q_score

  if return_tokens:
- return_dict["tokens"] = output
+ return_dict["tokens"] = outputs

  if return_text:
- decoded_output = self.tokenizer.batch_decode(output, skip_special_tokens=True)
+ decoded_output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
  decoded_output = [i.strip() for i in decoded_output]
  return_dict["text"] = decoded_output

- if len(output) == 1:
+ if is_single:
  for k, v in return_dict.items():
  return_dict[k] = v[0]

- if return_text and len(return_dict) == 0:
+ if return_text and len(return_dict) == 1:
  return decoded_output

  return return_dict
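Editor's note: with the WatermarkerBase.py changes above, a Watermarker can be built directly from a model name (the tokenizer is resolved automatically) and generate accepts a list of prompts that is padded and batched internally. A hedged usage sketch based on the new signatures; the watermark id and the prompt are placeholders, and the model name is the default used elsewhere in this package:

    from waterfall.WatermarkerBase import Watermarker

    # model may now be a Hugging Face repo id string instead of a loaded PreTrainedModel
    wm = Watermarker(model="meta-llama/Llama-3.1-8B-Instruct", id=42, kappa=6, device="auto")

    out = wm.generate(
        prompts=["Paraphrase this sentence about watermarking."],
        return_text=True,
        return_scores=True,   # adds "q_score" alongside "text" in the returned dict
        use_tqdm=True,        # progress bar over generation batches
    )
    print(out["text"][0], out["q_score"][0])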
waterfall/permute.py CHANGED
@@ -49,16 +49,23 @@ class Permute:
  size_per_permutation_in_bytes = N * self.dtype.itemsize
  cache_size = int(psutil.virtual_memory().total * 0.02 / size_per_permutation_in_bytes) # 2% of total memory
  self.permutations.capacity = cache_size
+ self.no_permutation = np.arange(self.N, dtype=self.dtype)
+
+ def _permute(self, key):
+ return np.random.RandomState(key).permutation(self.N).astype(self.dtype)

  def get_permutation(self, prev_tok, id : int, cache : bool = False) -> np.ndarray:
+ # Skip special tokens
+ if any((i >= self.N for i in prev_tok)):
+ return self.no_permutation
  key = (id, *prev_tok)
  if cache:
  permutation = self.permutations.get(key)
  if permutation is None:
- permutation = np.random.RandomState(key).permutation(self.N).astype(self.dtype)
+ permutation = self._permute(key)
  self.permutations.put(key, permutation)
  else:
- permutation = np.random.RandomState(key).permutation(self.N).astype(self.dtype)
+ permutation = self._permute(key)
  return permutation

  def get_unshuffled_indices(self, ids, args) -> dict[int, np.ndarray]:
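Editor's note: the permutation in _permute is seeded with the tuple (id, *prev_tok), so the same watermark id and preceding n-gram always map to the same vocabulary shuffle. A standalone sketch of that seeding behaviour (not package code; the key values are placeholders):

    import numpy as np

    N = 32000
    key = (42, 17, 93)                        # (id, *prev_tok), placeholder values
    perm_a = np.random.RandomState(key).permutation(N)
    perm_b = np.random.RandomState(key).permutation(N)
    assert (perm_a == perm_b).all()           # deterministic for the same key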
waterfall/watermark.py CHANGED
@@ -8,15 +8,13 @@ from typing import List, Literal, Optional, Tuple

  from transformers import AutoTokenizer, AutoModelForCausalLM
  from transformers.modeling_utils import PreTrainedModel
+ from transformers.generation.configuration_utils import GenerationConfig
  from sentence_transformers import SentenceTransformer
- from tqdm.auto import tqdm

  from waterfall.WatermarkingFnFourier import WatermarkingFnFourier
  from waterfall.WatermarkingFnSquare import WatermarkingFnSquare
  from waterfall.WatermarkerBase import Watermarker

- os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
  PROMPT = (
  "Paraphrase the user provided text while preserving semantic similarity. "
  "Do not include any other sentences in the response, such as explanations of the paraphrasing. "
@@ -24,7 +22,7 @@ PROMPT = (
  )
  PRE_PARAPHRASED = "Here is a paraphrased version of the text while preserving the semantic similarity:\n\n"

- waterfall_cached_watermarking_model = None # Global variable to cache the watermarking model
+ waterfall_cached_watermarking_model: PreTrainedModel | None = None # Global variable to cache the watermarking model

  def detect_gpu() -> str:
  """
@@ -41,47 +39,137 @@ def detect_gpu() -> str:
  else:
  return 'cpu'

- def watermark(
- T_o: str,
- watermarker: Watermarker,
- sts_model: SentenceTransformer,
+ def watermark_texts(
+ T_os: List[str],
+ id: Optional[int] = None,
+ k_p: int = 1,
+ kappa: float = 2.0,
+ model_path: Optional[str] = "meta-llama/Llama-3.1-8B-Instruct",
+ sts_model_path: Optional[str] = "sentence-transformers/all-mpnet-base-v2",
+ watermark_fn: Literal["fourier", "square"] = "fourier",
+ watermarker: Optional[Watermarker] = None,
+ sts_model: Optional[SentenceTransformer] = None,
+ device: str = detect_gpu(),
+ STS_scale: float = 2.0,
+ use_tqdm: bool = False,
+ do_sample: bool = False,
+ temperature: Optional[float] = None,
+ top_p: Optional[float] = None,
+ max_new_tokens: Optional[int] = None,
  num_beam_groups: int = 4,
  beams_per_group: int = 2,
- STS_scale: float = 2.0,
  diversity_penalty: float = 0.5,
- max_new_tokens: Optional[int] = None,
- **kwargs
- ) -> str:
- paraphrasing_prompt = watermarker.tokenizer.apply_chat_template(
- [
- {"role":"system", "content":PROMPT},
- {"role":"user", "content":T_o},
- ], tokenize=False, add_generation_prompt = True) + PRE_PARAPHRASED
+ stop_at_double_newline: bool = True, # if True, will stop generation at the first double newline. Prevent repeated paraphrasing of the same text.
+ ) -> List[str]:
+ if watermark_fn == 'fourier':
+ watermarkingFnClass = WatermarkingFnFourier
+ elif watermark_fn == 'square':
+ watermarkingFnClass = WatermarkingFnSquare
+ else:
+ raise ValueError("Invalid watermarking function")
+
+ # Check if watermarker/model/tokenizer are loaded
+ if watermarker is None:
+ assert model_path is not None, "model_path must be provided if watermarker is not passed"
+ assert id is not None, "id must be provided if watermarker is not passed"
+ global waterfall_cached_watermarking_model
+
+ if isinstance(waterfall_cached_watermarking_model, PreTrainedModel) and waterfall_cached_watermarking_model.name_or_path != model_path:
+ device = waterfall_cached_watermarking_model.device.type
+ waterfall_cached_watermarking_model = None
+ gc.collect()
+ if device == "cuda":
+ torch.cuda.empty_cache()
+ elif device == "mps":
+ torch.mps.empty_cache()
+
+ if waterfall_cached_watermarking_model is None:
+ model = model_path
+ else:
+ model = waterfall_cached_watermarking_model
+
+ watermarker = Watermarker(model=model, id=id, kappa=kappa, k_p=k_p, watermarkingFnClass=watermarkingFnClass)
+ else:
+ device = watermarker.model.device.type
+ id = watermarker.id
+ waterfall_cached_watermarking_model = watermarker.model
+
+ # Check if sts model is loaded
+ if sts_model is None:
+ assert sts_model_path is not None, "sts_model_path must be provided if sts_model is not passed"
+ sts_model = SentenceTransformer(sts_model_path, device=device)
+
+ # Replace all \n\n in source text if stop_at_double_newline is True
+ # Models tend to generate \n\n before endlessly repeating itself, so we want to stop the model from doing that
+ if stop_at_double_newline:
+ for i in range(len(T_os)):
+ if "\n\n" in T_os[i]:
+ logging.warning(f"Text idx {i} contains \\n\\n and stop_at_double_newline is set to True, replacing all \\n\\n in text.")
+ T_os[i] = T_os[i].replace("\n\n", " ") # replace double newlines with space
+
+ # Add system prompt and prefill, and format into appropriate chat format
+ formatted_T_os = watermarker.format_prompt(
+ T_os,
+ system_prompt=PROMPT,
+ assistant_prefill=PRE_PARAPHRASED,
+ )
+
+ if max_new_tokens is None:
+ max_input_len = max(len(p) for p in formatted_T_os)
+ max_new_tokens = max_input_len
+
+ if do_sample:
+ assert (do_sample and temperature is not None and top_p is not None and num_beam_groups == 1 and beams_per_group == 1), \
+ "do_sample=True requires temperature, top_p, num_beam_groups=1 and beams_per_group=1"
+ else: # Using beam search
+ assert (not do_sample and temperature is None and top_p is None and num_beam_groups >= 1 and beams_per_group >= 1), \
+ "do_sample=False requires temperature=None, top_p=None, num_beam_groups>=1 and beams_per_group>=1"
+
+ eos_token_id = watermarker.tokenizer.eos_token_id
+ # add "\n\n" tokens to eos_token_id list
+ if stop_at_double_newline:
+ eos_token_id = [eos_token_id]
+ # llama tokenizer's .vocab() has weird symbols and doesn't work with GenerationConfig's stop_strings, so we have to brute force check all tokens
+ for token_id,string in enumerate(watermarker.tokenizer.batch_decode(torch.arange(watermarker.tokenizer.vocab_size).unsqueeze(1))):
+ if "\n\n" in string:
+ eos_token_id.append(token_id)
+
+ generation_config = GenerationConfig(
+ max_new_tokens=max_new_tokens,
+ do_sample=do_sample,
+ temperature=temperature,
+ top_p=top_p,
+ num_beam_groups=num_beam_groups,
+ num_beams=num_beam_groups * beams_per_group,
+ diversity_penalty=diversity_penalty,
+ eos_token_id=eos_token_id,
+ num_return_sequences=num_beam_groups * beams_per_group,
+ )

  watermarked = watermarker.generate(
- paraphrasing_prompt,
- return_scores = True,
- max_new_tokens = int(len(paraphrasing_prompt) * 1.5) if max_new_tokens is None else max_new_tokens,
- do_sample = False, temperature=None, top_p=None,
- num_beams = num_beam_groups * beams_per_group,
- num_beam_groups = num_beam_groups,
- num_return_sequences = num_beam_groups * beams_per_group,
- diversity_penalty = diversity_penalty,
- **kwargs,
- )
+ prompts=formatted_T_os,
+ return_text=True,
+ return_scores=True,
+ use_tqdm=use_tqdm,
+ generation_config=generation_config,
+ )
+ T_ws = watermarked["text"]
+ # Reshape T_ws to Queries X Beams
+ num_beams = num_beam_groups * beams_per_group
+ T_ws = [T_ws[i * num_beams:(i + 1) * num_beams] for i in range(len(T_os))]

  # Select best paraphrasing based on q_score and semantic similarity
- sts_scores = STS_scorer(T_o, watermarked["text"], sts_model)
- selection_score = sts_scores * STS_scale + torch.from_numpy(watermarked["q_score"])
- selection = torch.argmax(selection_score)
+ sts_scores = STS_scorer_batch(T_os, T_ws, sts_model)
+ selection_scores = sts_scores * STS_scale + torch.from_numpy(watermarked["q_score"]).reshape(-1, num_beams)
+ selections = torch.argmax(selection_scores, dim = -1)

- T_w = watermarked["text"][selection]
+ T_ws = [T_w[selection] for T_w, selection in zip(T_ws, selections)]

- return T_w
+ return T_ws

- def verify_texts(texts: List[str], id: int,
- watermarker: Optional[Watermarker] = None,
- k_p: Optional[int] = None,
+ def verify_texts(texts: List[str], id: int,
+ watermarker: Optional[Watermarker] = None,
+ k_p: Optional[int] = None,
  model_path: Optional[str] = "meta-llama/Llama-3.1-8B-Instruct",
  return_extracted_k_p: bool = False
  ) -> np.ndarray | Tuple[np.ndarray,np.ndarray]:
@@ -89,9 +177,8 @@ def verify_texts(texts: List[str], id: int,

  if watermarker is None:
  assert model_path is not None, "model_path must be provided if watermarker is not passed"
- tokenizer = AutoTokenizer.from_pretrained(model_path)
- watermarker = Watermarker(tokenizer=tokenizer)
-
+ watermarker = Watermarker(tokenizer=model_path)
+
  if k_p is None:
  k_p = watermarker.k_p

@@ -135,87 +222,6 @@ def STS_scorer(
  cos_sim = cos_sim.item()
  return cos_sim

- def watermark_texts(
- T_os: List[str],
- id: Optional[int] = None,
- k_p: int = 1,
- kappa: float = 2.0,
- model_path: str = "meta-llama/Llama-3.1-8B-Instruct",
- torch_dtype: torch.dtype = torch.bfloat16,
- sts_model_path: str = "sentence-transformers/all-mpnet-base-v2",
- watermark_fn: Literal["fourier", "square"] = "fourier",
- watermarker: Optional[Watermarker] = None,
- sts_model: Optional[SentenceTransformer] = None,
- device: str = detect_gpu(),
- num_beam_groups: int = 4,
- beams_per_group: int = 2,
- diversity_penalty: float = 0.5,
- STS_scale:float = 2.0,
- use_tqdm: bool = False,
- stop_at_double_newline: bool = True, # if True, will stop generation at the first double newline. Prevent repeated paraphrasing of the same text.
- ) -> List[str]:
- if watermark_fn == 'fourier':
- watermarkingFnClass = WatermarkingFnFourier
- elif watermark_fn == 'square':
- watermarkingFnClass = WatermarkingFnSquare
- else:
- raise ValueError("Invalid watermarking function")
-
- if watermarker is None:
- assert model_path is not None, "model_path must be provided if watermarker is not passed"
- global waterfall_cached_watermarking_model
-
- if isinstance(waterfall_cached_watermarking_model, PreTrainedModel) and waterfall_cached_watermarking_model.name_or_path != model_path:
- device = waterfall_cached_watermarking_model.device.type
- waterfall_cached_watermarking_model = None
- gc.collect()
- if device == "cuda":
- torch.cuda.empty_cache()
- elif device == "mps":
- torch.mps.empty_cache()
-
- if waterfall_cached_watermarking_model is None:
- waterfall_cached_watermarking_model = AutoModelForCausalLM.from_pretrained(
- model_path,
- torch_dtype=torch_dtype,
- device_map=device,
- )
- model = waterfall_cached_watermarking_model
- tokenizer = AutoTokenizer.from_pretrained(model_path)
-
- watermarker = Watermarker(tokenizer=tokenizer, model=model, id=id, kappa=kappa, k_p=k_p, watermarkingFnClass=watermarkingFnClass)
- else:
- tokenizer = watermarker.tokenizer
- device = watermarker.model.device
- id = watermarker.id
-
- if id is None:
- raise Exception("ID or Watermarker class must be passed to watermark_texts.")
-
- if sts_model is None:
- assert sts_model_path is not None, "sts_model_path must be provided if sts_model is not passed"
- sts_model = SentenceTransformer(sts_model_path, device=device)
-
- T_ws = []
-
- for T_o in tqdm(T_os, desc="Watermarking texts", disable=not use_tqdm):
- if stop_at_double_newline and "\n\n" in T_o:
- logging.warning("Text contains \\n\\n and stop_at_double_newline is set to True, replacing all \\n\\n in text.")
- T_o = T_o.replace("\n\n", " ") # replace double newlines with space
- T_w = watermark(
- T_o,
- watermarker = watermarker,
- sts_model = sts_model,
- num_beam_groups = num_beam_groups,
- beams_per_group = beams_per_group,
- diversity_penalty = diversity_penalty,
- STS_scale = STS_scale,
- stop_strings=["\n\n"] if stop_at_double_newline else None,
- )
- T_ws.append(T_w)
-
- return T_ws
-
  def pretty_print(
  T_o: str, T_w: str,
  sts_score: float,
@@ -303,13 +309,15 @@ def main():
  sts_model = SentenceTransformer(sts_model_name, device=device)

  T_ws = watermark_texts(
- T_os, id, k_p, kappa,
- watermarker=watermarker, sts_model=sts_model,
+ T_os,
+ id=id, k_p=k_p, kappa=kappa,
+ watermarker=watermarker,
+ sts_model=sts_model,
  beams_per_group=beams_per_group,
  num_beam_groups=num_beam_groups,
  diversity_penalty=diversity_penalty,
  STS_scale=STS_scale,
- use_tqdm=True
+ use_tqdm=True,
  )

  # watermarker = Watermarker(tokenizer=tokenizer, model=None, id=id, k_p=k_p, watermarkingFnClass=watermarkingFnClass) # If only verifying the watermark, do not need to instantiate the model
@@ -320,7 +328,7 @@
  # in an IDE or something else without terminal size
  try:
  column_size = os.get_terminal_size().columns
- except OSError as ose:
+ except OSError:
  column_size = 80

  print("=" * column_size)
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: waterfall
- Version: 0.1.7
+ Version: 0.2.1
  Summary: Scalable Framework for Robust Text Watermarking and Provenance for LLMs
  Project-URL: Homepage, https://github.com/aoi3142/Waterfall
  Project-URL: Issues, https://github.com/aoi3142/Waterfall/issues
@@ -15,7 +15,7 @@ Requires-Dist: numpy>=2.0.0
  Requires-Dist: scipy>=1.13.0
  Requires-Dist: sentence-transformers>=3.0.0
  Requires-Dist: torch>=2.3.0
- Requires-Dist: transformers>=4.43.1
+ Requires-Dist: transformers<4.55.0,>=4.43.1
  Description-Content-Type: text/markdown

  # Waterfall: Scalable Framework for Robust Text Watermarking and Provenance for LLMs [EMNLP 2024 Main Long]
@@ -0,0 +1,12 @@
+ waterfall/WatermarkerBase.py,sha256=A2VRfsnBfz6-8DSL2NKQZdM1OLI0sQ73qjYaV6rIgJ0,20822
+ waterfall/WatermarkingFn.py,sha256=-b-kGRdL0a7eKRqJmcHPAR_rCjxQYnsg1Ne6bTwBc1I,1931
+ waterfall/WatermarkingFnFourier.py,sha256=QYayAQYwi1dQkDIyqmvhU568VhrVYTVy47HkI8F8SZs,1358
+ waterfall/WatermarkingFnSquare.py,sha256=2PAO05DdKT02npo7GDf_82D520nP7kGAWK6H4E4JMt4,1638
+ waterfall/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ waterfall/permute.py,sha256=uYKdmn4pGvjB6hInInLGxFIF6vt507lqJ_qe-ST1PFE,2783
+ waterfall/watermark.py,sha256=IbH5r3oqjtKztDVryfDTr_NDn-CLZHow0S8nAEtZmdc,14420
+ waterfall-0.2.1.dist-info/METADATA,sha256=Mzyp7Nw395RLCN3wnzp2StEpKZEN2erb5BvCOd5Z-4I,8722
+ waterfall-0.2.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ waterfall-0.2.1.dist-info/entry_points.txt,sha256=XXnUzuWXu2nc9j4WAll9tq6HyodN_8WJLjeG0O4Y2Gw,60
+ waterfall-0.2.1.dist-info/licenses/LICENSE,sha256=zAtaO-k41Q-Q4Etl4bzuh7pgNJsPH-dYfzvznRa0OvM,11341
+ waterfall-0.2.1.dist-info/RECORD,,
@@ -1,12 +0,0 @@
- waterfall/WatermarkerBase.py,sha256=NrDo4yJ4gnliTHH3LZemALpU_L-MCaPapevV1YnRHuE,12999
- waterfall/WatermarkingFn.py,sha256=-b-kGRdL0a7eKRqJmcHPAR_rCjxQYnsg1Ne6bTwBc1I,1931
- waterfall/WatermarkingFnFourier.py,sha256=QYayAQYwi1dQkDIyqmvhU568VhrVYTVy47HkI8F8SZs,1358
- waterfall/WatermarkingFnSquare.py,sha256=2PAO05DdKT02npo7GDf_82D520nP7kGAWK6H4E4JMt4,1638
- waterfall/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- waterfall/permute.py,sha256=RwxOHFhx_VSOhhFwy5s79YgwTUBkfW2-LCCXYR3VT2o,2582
- waterfall/watermark.py,sha256=W5jYGqYGOXXO-KLPKzJoin5zC_Xb6Xk9BzsAA9-LKXA,13494
- waterfall-0.1.7.dist-info/METADATA,sha256=-QVkPeyZWXdPHr_SvhvIFyCNy3G2GuzHKPmg9w8Z1-I,8714
- waterfall-0.1.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- waterfall-0.1.7.dist-info/entry_points.txt,sha256=XXnUzuWXu2nc9j4WAll9tq6HyodN_8WJLjeG0O4Y2Gw,60
- waterfall-0.1.7.dist-info/licenses/LICENSE,sha256=zAtaO-k41Q-Q4Etl4bzuh7pgNJsPH-dYfzvznRa0OvM,11341
- waterfall-0.1.7.dist-info/RECORD,,