waterfall 0.1.6.tar.gz → 0.2.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: waterfall
- Version: 0.1.6
+ Version: 0.2.0
  Summary: Scalable Framework for Robust Text Watermarking and Provenance for LLMs
  Project-URL: Homepage, https://github.com/aoi3142/Waterfall
  Project-URL: Issues, https://github.com/aoi3142/Waterfall/issues
@@ -15,7 +15,7 @@ Requires-Dist: numpy>=2.0.0
  Requires-Dist: scipy>=1.13.0
  Requires-Dist: sentence-transformers>=3.0.0
  Requires-Dist: torch>=2.3.0
- Requires-Dist: transformers>=4.43.1
+ Requires-Dist: transformers<4.55.0,>=4.43.1
  Description-Content-Type: text/markdown

  # Waterfall: Scalable Framework for Robust Text Watermarking and Provenance for LLMs [EMNLP 2024 Main Long]
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

  [project]
  name = "waterfall"
- version = "0.1.6"
+ version = "0.2.0"
  authors = [
    { name = "Xinyuan Niu", email="aperture@outlook.sg" }
  ]
@@ -24,7 +24,7 @@ dependencies = [
    "scipy>=1.13.0",
    "sentence-transformers>=3.0.0",
    "torch>=2.3.0",
-   "transformers>=4.43.1",
+   "transformers>=4.43.1,<4.55.0",
  ]

  [project.urls]
@@ -3,4 +3,4 @@ numpy>=2.0.0
  scipy>=1.13.0
  sentence-transformers>=3.0.0
  torch>=2.3.0
- transformers>=4.43.1
+ transformers>=4.43.1,<4.55.0
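The same upper bound on transformers now appears in every place the dependency is declared (sdist metadata, pyproject.toml, and requirements). As a quick sanity check that a consuming environment satisfies the new pin, something like the following could be used; this is a minimal sketch using importlib.metadata and packaging, not code shipped by the package:

from importlib.metadata import version
from packaging.specifiers import SpecifierSet

# Hypothetical check, not part of waterfall itself: confirm the installed
# transformers version falls inside the range pinned by waterfall 0.2.0.
assert version("transformers") in SpecifierSet(">=4.43.1,<4.55.0")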
@@ -0,0 +1,487 @@
+ import gc
+ import logging
+ import os
+ import time
+ from collections import defaultdict
+ from functools import partial
+ from multiprocessing import Pool
+ from typing import List, Tuple, Optional
+ from itertools import repeat
+
+ import numpy as np
+ import torch
+ from scipy.sparse import csr_matrix, vstack
+ from tqdm import tqdm
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.tokenization_utils_base import PreTrainedTokenizerBase, BatchEncoding
+ from transformers.generation.logits_process import LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper
+ from transformers.generation.configuration_utils import GenerationConfig
+
+ from waterfall.permute import Permute
+ from waterfall.WatermarkingFn import WatermarkingFn
+ from waterfall.WatermarkingFnFourier import WatermarkingFnFourier
+
+ class PerturbationProcessor(LogitsProcessor):
+     def __init__(self,
+                  N : int = 32000, # Vocab size
+                  id : int = 0,    # Watermark ID
+                  ) -> None:
+
+         self.id = id
+         self.N = N
+         self.init_token_count = None
+         self.phi = torch.zeros(N)
+         self.n_gram = 2
+
+         self.skip_watermark = False
+
+         self.permute = Permute(self.N)
+
+     def reset(self, n_gram : int = 2) -> None:
+         self.n_gram = n_gram
+         self.init_token_count = None
+         if torch.allclose(self.phi, torch.median(self.phi)):
+             self.skip_watermark = True
+             logging.warning(f"Generating without watermark as watermarking function is flat")
+         else:
+             self.skip_watermark = False
+
+     def set_phi(self, phi : np.ndarray) -> None:
+         self.phi = torch.from_numpy(phi)
+
+     def __call__(self, input_ids: torch.LongTensor,
+                  scores: torch.FloatTensor) -> torch.FloatTensor:
+
+         if self.skip_watermark:
+             return scores
+
+         if self.init_token_count is None:
+             self.init_token_count = input_ids.shape[1]
+
+         # Insufficient tokens generated for n-gram
+         if self.init_token_count + self.n_gram - 1 > input_ids.shape[1]:
+             return scores
+
+         # using numpy as PyTorch tensors don't hash properly for rng and dict key
+         prev_tokens = input_ids[:,-self.n_gram+1:].cpu().numpy()
+
+         permutations = (
+             self.permute.get_permutation(prev_tokens[i,:], self.id, cache=True)
+             for i in range(prev_tokens.shape[0])
+         )
+         perturbations = torch.stack([
+             self.phi[permutation] for permutation in permutations
+         ])
+         scores[:,:self.N] += perturbations.to(device=scores.device, dtype=scores.dtype)
+         return scores
+
+ def indices_to_counts(N : int, dtype : np.dtype, indices : np.ndarray) -> csr_matrix:
+     counts = csr_matrix([np.bincount(j, minlength=N).astype(dtype) for j in indices])
+     return counts
+
+ class Watermarker:
+     def __init__(self,
+                  tokenizer : Optional[PreTrainedTokenizerBase | str] = None,
+                  model : Optional[PreTrainedModel | str] = None,
+                  id : int = 0,
+                  kappa : float = 6,
+                  k_p : int = 1,
+                  n_gram : int = 2,
+                  watermarkingFnClass = WatermarkingFnFourier,
+                  device = None,
+                  ) -> None:
+         assert kappa >= 0, f"kappa must be >= 0, value provided is {kappa}"
+
+         self.id = id
+         self.k_p = k_p
+         self.n_gram = n_gram
+         self.kappa = kappa
+
+         if tokenizer is None:
+             if isinstance(model, str):
+                 self.tokenizer = AutoTokenizer.from_pretrained(model)
+             elif isinstance(model, PreTrainedModel):
+                 self.tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
+             else:
+                 raise NotImplementedError("tokenizer must be provided or model must be a string or PreTrainedModel")
+         elif isinstance(tokenizer, str):
+             self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+         else:
+             self.tokenizer = tokenizer
+         self.tokenizer.pad_token = self.tokenizer.eos_token
+         self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
+
+         self.N = self.tokenizer.vocab_size
+
+         self.logits_processor = PerturbationProcessor(N = self.N, id = self.id)
+
+         if isinstance(model, str):
+             self.load_model(model, device_map=device)
+         else:
+             self.model = model
+
+         assert (self.model is None) or isinstance(self.model, PreTrainedModel), f"model must be a transformers model, value provided is {type(self.model)}" # argument order for tokenizer and model was swapped relative to the original code
+
+         self.compute_phi(watermarkingFnClass)
+
+     def load_model(self, model_name_or_path : str, device_map : str = "auto"):
+         self.model = AutoModelForCausalLM.from_pretrained(
+             model_name_or_path,
+             device_map=device_map,
+         )
+
+     def compute_phi(self, watermarkingFnClass = WatermarkingFnFourier) -> None:
+         self.watermarking_fn: WatermarkingFn = watermarkingFnClass(id = self.id, k_p = self.k_p, N = self.N, kappa = self.kappa)
+         self.phi = self.watermarking_fn.phi
+
+         self.logits_processor.set_phi(self.phi)
+
+     # Format prompt(s) into chat template
+     def format_prompt(
+         self,
+         T_os : str | List[str],
+         system_prompt : Optional[str] = None,
+         assistant_prefill : Optional[str | List[str]] = "",
+     ) -> str | List[str]:
+         if isinstance(system_prompt, str):
+             _system_prompt = {"role":"system", "content":system_prompt}
+         is_single = isinstance(T_os, str)
+         if is_single:
+             T_os = [T_os]
+         if not isinstance(assistant_prefill, list):
+             assistant_prefill = repeat(assistant_prefill, len(T_os))
+         else:
+             assert len(assistant_prefill) == len(T_os), "Length of assistant_prefill must match length of T_os"
+         formatted_prompts = []
+         for T_o, prefill in zip(T_os, assistant_prefill):
+             formatted_prompt : str = self.tokenizer.apply_chat_template(
+                 [
+                     _system_prompt,
+                     {"role":"user", "content":T_o},
+                 ], tokenize=False, add_generation_prompt = True)
+             if prefill is not None:
+                 formatted_prompt += prefill
+             formatted_prompts.append(formatted_prompt)
+         if is_single:
+             return formatted_prompts[0]
+         return formatted_prompts
+
+     # Find the largest batch size that fits in GPU memory
+     def find_largest_batch_size(
+         self,
+         tokd_inputs : List[BatchEncoding],
+         logits_processor : List[LogitsProcessor] = [],
+         **kwargs,
+     ):
+         longest_idx = np.argmax([tokd_input["input_ids"].shape[-1] for tokd_input in tokd_inputs])
+         if "generation_config" in kwargs:
+             generation_config = GenerationConfig(**kwargs["generation_config"].to_dict()) # copy
+             max_new_tokens = generation_config.max_new_tokens
+         else:
+             generation_config = GenerationConfig(**kwargs)
+             max_new_tokens = kwargs.get("max_new_tokens", 2048)
+         generation_config.update(max_new_tokens=1)
+         input_ids = tokd_inputs[longest_idx]["input_ids"]
+         input_ids = torch.zeros(
+             (1, max_new_tokens + input_ids.shape[-1] - 1),
+             dtype=input_ids.dtype,
+             device=self.model.device
+         )
+         max_batch_size = 1
+         with torch.no_grad():
+             while max_batch_size < min(16, len(tokd_inputs)):
+                 torch.cuda.empty_cache()
+                 try:
+                     _ = self.model.generate(
+                         input_ids=input_ids,
+                         attention_mask=torch.ones_like(input_ids),
+                         logits_processor=logits_processor,
+                         generation_config=generation_config,
+                         pad_token_id=self.tokenizer.eos_token_id,
+                         tokenizer=self.tokenizer,
+                     )
+                     max_batch_size = input_ids.shape[0]
+                 except RuntimeError as e:
+                     if "CUDA out of memory" in str(e):
+                         break
+                     else:
+                         raise e
+                 input_ids = torch.cat([input_ids, input_ids], dim=0)
+         torch.cuda.empty_cache()
+         return max_batch_size
+
+     def generate(
+         self,
+         prompts : Optional[str | List[str]] = None,
+         tokd_inputs : Optional[torch.Tensor | List[torch.Tensor] | BatchEncoding | List[BatchEncoding]] = None,
+         n_gram : Optional[int] = None,
+         return_text : bool = True,
+         return_tokens : bool = False,
+         return_scores : bool = False,
+         use_tqdm : bool = False,
+         batched_generate : bool = True,
+         **kwargs # Other generate parameters
+     ) -> List[str] | dict: # Returns flattened list of query x beam
+
+         assert self.model is not None, "Model is not loaded. Please load the model before generating text."
+
+         is_single = isinstance(prompts, str) or isinstance(tokd_inputs, torch.Tensor)
+         if is_single:
+             prompts = [prompts] if prompts is not None else None
+             tokd_inputs = [tokd_inputs] if tokd_inputs is not None else None
+
+         if n_gram is None:
+             n_gram = self.n_gram
+         if tokd_inputs is None:
+             assert prompts is not None, "Either prompt or tokd_input must be provided."
+             tokd_inputs = [self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False) for prompt in prompts]
+
+         # If tokd_input is a tensor, convert it to a BatchEncoding
+         squeezed_tokd_inputs = []
+         for tokd_input in tokd_inputs:
+             if isinstance(tokd_input, torch.Tensor):
+                 input_ids = tokd_input
+                 attention_mask = torch.ones_like(tokd_input)
+             else:
+                 input_ids = tokd_input["input_ids"]
+                 attention_mask = tokd_input["attention_mask"]
+             if input_ids.ndim == 2:
+                 input_ids = input_ids.squeeze()
+                 attention_mask = attention_mask.squeeze()
+             squeezed_tokd_inputs.append(BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask}))
+         tokd_inputs = squeezed_tokd_inputs
+
+         logits_processor = []
+         # Ensure top_k and top_p happen before watermarking
+         if "generation_config" in kwargs:
+             generation_config: GenerationConfig = kwargs["generation_config"]
+             top_k = generation_config.top_k
+             top_p = generation_config.top_p
+             generation_config.update(top_p=1.0)
+         else:
+             top_k = kwargs.pop("top_k", None)
+             top_p = kwargs.pop("top_p", None)
+
+         if top_k is not None and top_k != 0:
+             logits_processor.append(TopKLogitsWarper(top_k))
+         if top_p is not None and top_p < 1.0:
+             logits_processor.append(TopPLogitsWarper(top_p))
+         if self.kappa != 0:
+             logits_processor.append(self.logits_processor)
+
+         if batched_generate and len(tokd_inputs) >= 8:
+             max_batch_size = self.find_largest_batch_size(tokd_inputs, logits_processor=logits_processor, **kwargs)
+         else:
+             max_batch_size = 1
+
+         # Group inputs by token length
+         if max_batch_size > 1:
+             tokd_inputs_order = sorted(range(len(tokd_inputs)), key=lambda i: tokd_inputs[i]["input_ids"].shape[-1])
+             tokd_inputs = [tokd_inputs[i] for i in tokd_inputs_order]
+         else:
+             tokd_inputs_order = range(len(tokd_inputs))
+         tokd_input_batches = []
+         for i in range(0, len(tokd_inputs), max_batch_size):
+             batch = self.tokenizer.pad(tokd_inputs[i:i+max_batch_size], padding=True, padding_side="left").to(self.model.device, non_blocking=True)
+             tokd_input_batches.append(batch)
+         torch.cuda.synchronize()
+
+         outputs = []
+         with torch.no_grad():
+             bar = tqdm(total=len(tokd_inputs), desc="Generating text", disable=not use_tqdm)
+             for tokd_input_batch in tokd_input_batches:
+                 self.logits_processor.reset(n_gram)
+                 output = self.model.generate(
+                     **tokd_input_batch,
+                     logits_processor=logits_processor,
+                     pad_token_id=self.tokenizer.eos_token_id,
+                     tokenizer=self.tokenizer,
+                     **kwargs
+                 )
+                 output = output[:,tokd_input_batch["input_ids"].shape[-1]:].to("cpu", non_blocking=True)
+                 outputs.append(output)
+                 bar.update(tokd_input_batch["input_ids"].shape[0])
+         torch.cuda.synchronize()
+         outputs = [j for i in outputs for j in i] # Flatten the list of outputs
+
+         # Restore original ordering
+         if max_batch_size > 1:
+             reordered_outputs = [None] * len(outputs)
+             num_return_sequences = len(outputs) // len(tokd_inputs)
+             for i, idx in enumerate(tokd_inputs_order):
+                 reordered_outputs[idx * num_return_sequences:(idx + 1) * num_return_sequences] = outputs[i * num_return_sequences:(i + 1) * num_return_sequences]
+             outputs = reordered_outputs
+
+         return_dict = {}
+
+         if return_scores:
+             cumulative_token_count = self.get_cumulative_token_count(self.id, outputs, n_gram = n_gram, return_dense=False)
+             cumulative_token_count = vstack([i[0] for i in cumulative_token_count], format="csr")
+             q_score, _, _ = self.watermarking_fn.q(cumulative_token_count, k_p = [self.k_p], use_tqdm=False)
+             return_dict["q_score"] = q_score
+
+         if return_tokens:
+             return_dict["tokens"] = outputs
+
+         if return_text:
+             decoded_output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+             decoded_output = [i.strip() for i in decoded_output]
+             return_dict["text"] = decoded_output
+
+         if is_single:
+             for k, v in return_dict.items():
+                 return_dict[k] = v[0]
+
+         if return_text and len(return_dict) == 1:
+             return decoded_output
+
+         return return_dict
+
+     def get_cumulative_token_count(
+         self,
+         ids : List[int] | int,
+         all_tokens : List[torch.Tensor] | torch.Tensor | List[np.ndarray] | np.ndarray | List[List[int]] | List[int],
+         n_gram : int = 2,
+         return_unshuffled_indices : bool = False,
+         use_tqdm : bool = False,
+         return_dense : bool = True,
+         batch_size : int = 2**8,
+     ) -> List[csr_matrix] | List[np.ndarray] | Tuple[List[csr_matrix], List[List[np.ndarray]]] | Tuple[List[np.ndarray], List[List[np.ndarray]]]:
+         if isinstance(ids, int):
+             ids = [ids]
+         if isinstance(all_tokens[0], int) or (isinstance(all_tokens, (np.ndarray, torch.Tensor)) and all_tokens.ndim == 1):
+             all_tokens = [all_tokens]
+         all_tokens = list(map(lambda x: x.cpu().numpy() if isinstance(x, torch.Tensor) else x, all_tokens))
+         max_length = max(map(len, all_tokens))
+         window = n_gram - 1
+
+         # Collect all unique seeds for pseudo-random number generation
+         key_index_dict = defaultdict(set)
+         all_keys = []
+         for i, tokens in enumerate(tqdm(all_tokens, desc="Collecting unique n-grams", disable=not use_tqdm)):
+             all_keys.append([])
+             for j in range(window, len(tokens)):
+                 prev_token = tuple(tokens[j-window:j])
+                 t = tokens[j]
+                 if t >= self.N:
+                     break
+                 key_index_dict[prev_token].add(t)
+                 all_keys[i].append((prev_token, t))
+         key_index_dict = {k:tuple(v) for k,v in key_index_dict.items()}
+
+         use_mp = len(all_tokens) > batch_size * 4
+         if use_mp:
+             p = Pool(len(os.sched_getaffinity(0))-1)
+             pool_map = partial(p.imap, chunksize=batch_size)
+         else:
+             pool_map = map
+
+         # Generate permutations for all unique seeds
+         permutations = pool_map(
+             partial(self.logits_processor.permute.get_unshuffled_indices, ids),
+             key_index_dict.items())
+         permutations = tqdm(permutations, total=len(key_index_dict), desc="Getting permutations", disable=not use_tqdm)
+         for k, value in zip(key_index_dict.keys(), permutations):
+             key_index_dict[k] = value
+
+         # Assign indices to unshuffled_indices
+         unshuffled_indices: List[np.ndarray] = [] # [text x id x length]
+         for keys in tqdm(all_keys, desc="Assigning indices", disable=not use_tqdm):
+             if len(keys) == 0:
+                 unshuffled_indices.append(np.zeros((len(ids), 0), dtype=np.min_scalar_type(self.N)))
+             else:
+                 unshuffled_indices.append(np.stack([key_index_dict[key][t] for key, t in keys]).T) # [id x length]
+
+         # Convert indices to counts
+         cumulative_token_count = pool_map(
+             partial(indices_to_counts, self.N, np.min_scalar_type(max_length)),
+             unshuffled_indices
+         )
+         cumulative_token_count = list(tqdm(cumulative_token_count, total=len(unshuffled_indices), desc="Counting tokens", disable=not use_tqdm))
+
+         if use_mp:
+             p.close()
+             p.join()
+
+         if return_dense:
+             cumulative_token_count = list(map(lambda x: x.toarray(), cumulative_token_count))
+
+         if return_unshuffled_indices:
+             return cumulative_token_count, unshuffled_indices
+         return cumulative_token_count
+
+     def verify(
+         self,
+         text : str | List[str],
+         id: Optional[int | List[int]] = None,
+         k_p : Optional[int | List[int]] = None,
+         return_ranking : bool = False,
+         return_extracted_k_p : bool = False,
+         return_counts : bool = False,
+         return_unshuffled_indices : bool = False,
+         use_tqdm : bool = False,
+         batch_size : int = 2**8,
+     ) -> np.ndarray | dict:
+         begin_time = time.time()
+
+         if id is None:
+             id = self.id
+
+         if isinstance(text, str):
+             texts = [text]
+         else:
+             texts = text
+
+         tokens = [np.array(self.tokenizer.encode(text, add_special_tokens=False), dtype=np.uint32) for text in tqdm(texts, desc="Tokenizing", disable=not use_tqdm)]
+
+         if isinstance(id, int):
+             ids = [id]
+         else:
+             ids = id
+
+         if k_p is None:
+             k_p = self.k_p
+
+         if isinstance(k_p, int):
+             k_ps = [k_p]
+         else:
+             k_ps = k_p
+
+         # Get cumulative token counts
+         start_time = time.time()
+         results = self.get_cumulative_token_count(ids, tokens, self.n_gram, return_unshuffled_indices, use_tqdm=use_tqdm, return_dense=False, batch_size=batch_size)
+         gc.collect()
+         if return_unshuffled_indices:
+             results, unshuffled_indices = results
+         results = vstack(results, format="csr")
+         if use_tqdm:
+             tqdm.write(f"Cumulative token counts done in {time.time() - start_time:.2f} seconds")
+
+         # Calculate Q score via dot product
+         start_time = time.time()
+         q_score, ranking, k_p_extracted = self.watermarking_fn.q(results, k_p = k_ps, batch = batch_size, use_tqdm = use_tqdm)
+         q_score, ranking = [i.reshape(-1, len(ids), i.shape[-1]) for i in (q_score, ranking)] # [text x ids x k_p for i in (score, rank)]
+         k_p_extracted = k_p_extracted.reshape(-1, len(ids)) # [text x ids]
+         if use_tqdm:
+             tqdm.write(f"Q score calculated in {time.time() - start_time:.2f} seconds")
+
+         res = q_score # [text x ids x k_p]
+
+         if return_ranking or return_extracted_k_p or return_counts or return_unshuffled_indices:
+             res = {
+                 "q_score": q_score, # [text x ids x k_p]
+             }
+             if return_ranking:
+                 res["ranking"] = ranking # [text x ids x k_p]
+             if return_extracted_k_p:
+                 res["k_p_extracted"] = k_p_extracted # [text x ids]
+             if return_counts:
+                 res["counts"] = results # [text x ids x k_p]
+             if return_unshuffled_indices:
+                 res["unshuffled_indices"] = unshuffled_indices # [text x ids x length]
+
+         if use_tqdm:
+             tqdm.write(f"Total time taken for verify: {time.time() - begin_time:.2f} seconds")
+
+         return res
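The new module above bundles watermarked generation (PerturbationProcessor applied inside model.generate) and verification (q-score over n-gram token counts) behind the Watermarker class. A minimal usage sketch based only on the signatures in this diff; the import path, model id, prompt, and watermark ID are placeholders, since the diff does not show the new file's name:

# Sketch only: the module path is an assumption (the diff omits the new
# file's name), and the model id, prompt, and watermark ID are placeholders.
from waterfall.WatermarkerBase import Watermarker  # assumed import path

wm = Watermarker(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    id=42,       # watermark ID to embed
    kappa=6,     # watermark strength
    k_p=1,
    n_gram=2,
)

prompt = wm.format_prompt("Paraphrase the following text: ...")
watermarked = wm.generate(prompt, max_new_tokens=256)

# verify() returns q-scores shaped [text x ids x k_p]; a larger score
# indicates stronger evidence of the watermark for that ID.
q_score = wm.verify(watermarked, id=42)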
@@ -49,16 +49,23 @@ class Permute:
          size_per_permutation_in_bytes = N * self.dtype.itemsize
          cache_size = int(psutil.virtual_memory().total * 0.02 / size_per_permutation_in_bytes) # 2% of total memory
          self.permutations.capacity = cache_size
+         self.no_permutation = np.arange(self.N, dtype=self.dtype)
+
+     def _permute(self, key):
+         return np.random.RandomState(key).permutation(self.N).astype(self.dtype)

      def get_permutation(self, prev_tok, id : int, cache : bool = False) -> np.ndarray:
+         # Skip special tokens
+         if any((i >= self.N for i in prev_tok)):
+             return self.no_permutation
          key = (id, *prev_tok)
          if cache:
              permutation = self.permutations.get(key)
              if permutation is None:
-                 permutation = np.random.RandomState(key).permutation(self.N).astype(self.dtype)
+                 permutation = self._permute(key)
                  self.permutations.put(key, permutation)
          else:
-             permutation = np.random.RandomState(key).permutation(self.N).astype(self.dtype)
+             permutation = self._permute(key)
          return permutation

      def get_unshuffled_indices(self, ids, args) -> dict[int, np.ndarray]:
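The permute.py change factors the seeded shuffle into a _permute helper and short-circuits when the n-gram context contains a token id at or above the base vocabulary size (e.g. an added special token), returning a precomputed identity permutation instead. A small sketch of the effect, assuming Permute is constructed with the vocabulary size as in the Watermarker code above; the token values are illustrative:

import numpy as np
from waterfall.permute import Permute

perm = Permute(32000)                            # N, as used via Permute(self.N) above
seeded = perm.get_permutation((1234,), id=42)    # in-vocab context: seeded shuffle
special = perm.get_permutation((40000,), id=42)  # context token >= N
# Tokens outside the base vocab now map to the identity permutation, so the
# logits perturbation becomes a no-op for that generation step.
assert np.array_equal(special, np.arange(32000))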