waterfall 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -24,6 +24,17 @@ from waterfall.WatermarkingFnFourier import WatermarkingFnFourier
24
24
 
25
25
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
26
26
 
27
+ # Check transformers version
28
+ import transformers
29
+ from packaging import version
30
+ # Group beam search is shifted to transformers-community package in 4.57.0
31
+ use_custom_group_beam_search = version.parse(transformers.__version__) >= version.parse("4.57.0")
32
+ # Set model loading kwargs based on transformers version
33
+ if version.parse(transformers.__version__) >= version.parse("4.56.0"):
34
+ model_from_pretrained_kwargs = {"dtype": "auto"}
35
+ else:
36
+ model_from_pretrained_kwargs = {"torch_dtype": torch.bfloat16}
37
+
27
38
  class PerturbationProcessor(LogitsProcessor):
28
39
  def __init__(self,
29
40
  N : int = 32000, # Vocab size
@@ -134,7 +145,7 @@ class Watermarker:
134
145
  self.model = AutoModelForCausalLM.from_pretrained(
135
146
  model_name_or_path,
136
147
  device_map=device_map,
137
- torch_dtype=dtype,
148
+ **model_from_pretrained_kwargs,
138
149
  )
139
150
 
140
151
  def compute_phi(self, watermarkingFnClass = WatermarkingFnFourier) -> None:
@@ -228,6 +239,7 @@ class Watermarker:
228
239
  use_tqdm : bool = False,
229
240
  batched_generate : bool = True,
230
241
  discard_incomplete : bool = True,
242
+ logits_processor = [],
231
243
  **kwargs # Other generate parameters
232
244
  ) -> List[str] | dict: # Returns flattened list of query x beam
233
245
 
@@ -259,7 +271,6 @@ class Watermarker:
259
271
  squeezed_tokd_inputs.append(BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask}))
260
272
  tokd_inputs = squeezed_tokd_inputs
261
273
 
262
- logits_processor = []
263
274
  # Ensure top_k and top_p happens before watermarking
264
275
  if "generation_config" in kwargs:
265
276
  generation_config: GenerationConfig = kwargs["generation_config"]
@@ -275,10 +286,13 @@ class Watermarker:
275
286
  top_k = kwargs.pop("top_k", None)
276
287
  top_p = kwargs.pop("top_p", None)
277
288
  temperature = kwargs.pop("temperature", 1.0)
278
- num_beams = kwargs.pop("num_beams", 1)
279
- diversity_penalty = kwargs.pop("diversity_penalty", None)
289
+ num_beams = kwargs.get("num_beams", 1)
290
+ diversity_penalty = kwargs.get("diversity_penalty", None)
280
291
  if num_beams <= 1:
281
292
  kwargs["diversity_penalty"] = None
293
+ if use_custom_group_beam_search:
294
+ kwargs["custom_generate"]="transformers-community/group-beam-search"
295
+ kwargs["trust_remote_code"]=True
282
296
 
283
297
  if num_beams > 1 and temperature is not None and temperature != 1.0:
284
298
  logits_processor.append(TemperatureLogitsWarper(float(temperature)))
@@ -351,10 +365,6 @@ class Watermarker:
351
365
  decoded_output = [i.strip() for i in decoded_output]
352
366
  return_dict["text"] = decoded_output
353
367
 
354
- if is_single:
355
- for k, v in return_dict.items():
356
- return_dict[k] = v[0]
357
-
358
368
  if return_text and len(return_dict) == 1:
359
369
  return decoded_output
360
370
 
waterfall/watermark.py CHANGED
@@ -15,6 +15,14 @@ from waterfall.WatermarkingFnFourier import WatermarkingFnFourier
15
15
  from waterfall.WatermarkingFnSquare import WatermarkingFnSquare
16
16
  from waterfall.WatermarkerBase import Watermarker
17
17
 
18
+ # Check transformers version
19
+ import transformers
20
+ from packaging import version
21
+ if version.parse(transformers.__version__) >= version.parse("4.56.0"):
22
+ model_from_pretrained_kwargs = {"dtype": "auto"}
23
+ else:
24
+ model_from_pretrained_kwargs = {"torch_dtype": torch.bfloat16}
25
+
18
26
  PROMPT = (
19
27
  "Paraphrase the user provided text while preserving semantic similarity. "
20
28
  "Do not include any other sentences in the response, such as explanations of the paraphrasing. "
@@ -306,8 +314,8 @@ def main():
306
314
  tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
307
315
  model = AutoModelForCausalLM.from_pretrained(
308
316
  model_name_or_path,
309
- torch_dtype=torch.bfloat16,
310
317
  device_map=device,
318
+ **model_from_pretrained_kwargs,
311
319
  )
312
320
 
313
321
  watermarker = Watermarker(tokenizer=tokenizer, model=model, id=id, kappa=kappa, k_p=k_p, watermarkingFnClass=watermarkingFnClass)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: waterfall
3
- Version: 0.2.10
3
+ Version: 0.2.12
4
4
  Summary: Scalable Framework for Robust Text Watermarking and Provenance for LLMs
5
5
  Project-URL: Homepage, https://github.com/aoi3142/Waterfall
6
6
  Project-URL: Issues, https://github.com/aoi3142/Waterfall/issues
@@ -15,7 +15,7 @@ Requires-Dist: numpy>=1.25.0
15
15
  Requires-Dist: scipy>=1.13.0
16
16
  Requires-Dist: sentence-transformers>=3.0.0
17
17
  Requires-Dist: torch>=2.3.0
18
- Requires-Dist: transformers<4.57.0,>=4.43.1
18
+ Requires-Dist: transformers>=4.43.1
19
19
  Description-Content-Type: text/markdown
20
20
 
21
21
  # Waterfall: Scalable Framework for Robust Text Watermarking and Provenance for LLMs [EMNLP 2024 Main Long]
@@ -0,0 +1,12 @@
1
+ waterfall/WatermarkerBase.py,sha256=6O_S78dD3Jha2OkJK2u3euwCH93i-mTiYYGXosPDMig,22632
2
+ waterfall/WatermarkingFn.py,sha256=-b-kGRdL0a7eKRqJmcHPAR_rCjxQYnsg1Ne6bTwBc1I,1931
3
+ waterfall/WatermarkingFnFourier.py,sha256=QYayAQYwi1dQkDIyqmvhU568VhrVYTVy47HkI8F8SZs,1358
4
+ waterfall/WatermarkingFnSquare.py,sha256=2PAO05DdKT02npo7GDf_82D520nP7kGAWK6H4E4JMt4,1638
5
+ waterfall/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
+ waterfall/permute.py,sha256=uYKdmn4pGvjB6hInInLGxFIF6vt507lqJ_qe-ST1PFE,2783
7
+ waterfall/watermark.py,sha256=fvscFoSbM51YUuDaOmrOKGvwXO25VMgGJKTfAeeKCaA,14817
8
+ waterfall-0.2.12.dist-info/METADATA,sha256=TBoeAFK8qkG-jIRi-OeKq4GbtFTDQkaUKyKuFNeDQHo,8760
9
+ waterfall-0.2.12.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
10
+ waterfall-0.2.12.dist-info/entry_points.txt,sha256=XXnUzuWXu2nc9j4WAll9tq6HyodN_8WJLjeG0O4Y2Gw,60
11
+ waterfall-0.2.12.dist-info/licenses/LICENSE,sha256=zAtaO-k41Q-Q4Etl4bzuh7pgNJsPH-dYfzvznRa0OvM,11341
12
+ waterfall-0.2.12.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: hatchling 1.27.0
2
+ Generator: hatchling 1.28.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
@@ -1,12 +0,0 @@
1
- waterfall/WatermarkerBase.py,sha256=1IvGo1rz1Ec-NW8rQ9bSC8KNdHawu4gl4CNsNncce7Q,22046
2
- waterfall/WatermarkingFn.py,sha256=-b-kGRdL0a7eKRqJmcHPAR_rCjxQYnsg1Ne6bTwBc1I,1931
3
- waterfall/WatermarkingFnFourier.py,sha256=QYayAQYwi1dQkDIyqmvhU568VhrVYTVy47HkI8F8SZs,1358
4
- waterfall/WatermarkingFnSquare.py,sha256=2PAO05DdKT02npo7GDf_82D520nP7kGAWK6H4E4JMt4,1638
5
- waterfall/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
- waterfall/permute.py,sha256=uYKdmn4pGvjB6hInInLGxFIF6vt507lqJ_qe-ST1PFE,2783
7
- waterfall/watermark.py,sha256=avyQIFJBhqu_q_ZBp0-RWvAOIJmzJvVisbiIca2GPyA,14536
8
- waterfall-0.2.10.dist-info/METADATA,sha256=vUr4PSIrQvPdkBjrsQc7uKJj4GVUSBZnpvElPA8n1Uc,8768
9
- waterfall-0.2.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
10
- waterfall-0.2.10.dist-info/entry_points.txt,sha256=XXnUzuWXu2nc9j4WAll9tq6HyodN_8WJLjeG0O4Y2Gw,60
11
- waterfall-0.2.10.dist-info/licenses/LICENSE,sha256=zAtaO-k41Q-Q4Etl4bzuh7pgNJsPH-dYfzvznRa0OvM,11341
12
- waterfall-0.2.10.dist-info/RECORD,,