waterfall 0.2.11__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
@@ -24,6 +24,17 @@ from waterfall.WatermarkingFnFourier import WatermarkingFnFourier
24
24
 
25
25
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
26
26
 
27
+ # Check transformers version
28
+ import transformers
29
+ from packaging import version
30
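+ # In transformers 4.57.0, group beam search moved out of the core library into the
+ # "transformers-community/group-beam-search" Hub repository (loaded via custom_generate).
+ # NOTE(review): using it requires trust_remote_code=True, which executes code downloaded from the Hub.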
31
+ use_custom_group_beam_search = version.parse(transformers.__version__) >= version.parse("4.57.0")
32
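+ # Set model loading kwargs based on transformers version: 4.56.0 renamed the
+ # `torch_dtype` argument of from_pretrained to `dtype`.
+ # NOTE(review): the branches differ in more than the keyword name. >=4.56.0 loads with
+ # dtype="auto" while older versions force bfloat16. These kwargs also replace the
+ # caller-supplied `dtype` previously passed to from_pretrained in Watermarker. Confirm this is intended.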
33
+ if version.parse(transformers.__version__) >= version.parse("4.56.0"):
34
+ model_from_pretrained_kwargs = {"dtype": "auto"}
35
+ else:
36
+ model_from_pretrained_kwargs = {"torch_dtype": torch.bfloat16}
37
+
27
38
  class PerturbationProcessor(LogitsProcessor):
28
39
  def __init__(self,
29
40
  N : int = 32000, # Vocab size
@@ -134,7 +145,7 @@ class Watermarker:
134
145
  self.model = AutoModelForCausalLM.from_pretrained(
135
146
  model_name_or_path,
136
147
  device_map=device_map,
137
- torch_dtype=dtype,
148
+ **model_from_pretrained_kwargs,
138
149
  )
139
150
 
140
151
  def compute_phi(self, watermarkingFnClass = WatermarkingFnFourier) -> None:
@@ -279,6 +290,9 @@ class Watermarker:
279
290
  diversity_penalty = kwargs.get("diversity_penalty", None)
280
291
  if num_beams <= 1:
281
292
  kwargs["diversity_penalty"] = None
293
+ if use_custom_group_beam_search:
294
+ kwargs["custom_generate"]="transformers-community/group-beam-search"
295
+ kwargs["trust_remote_code"]=True
282
296
 
283
297
  if num_beams > 1 and temperature is not None and temperature != 1.0:
284
298
  logits_processor.append(TemperatureLogitsWarper(float(temperature)))
waterfall/watermark.py CHANGED
@@ -15,6 +15,14 @@ from waterfall.WatermarkingFnFourier import WatermarkingFnFourier
15
15
  from waterfall.WatermarkingFnSquare import WatermarkingFnSquare
16
16
  from waterfall.WatermarkerBase import Watermarker
17
17
 
18
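+ # Check transformers version: 4.56.0 renamed from_pretrained's `torch_dtype` argument to `dtype`.
+ # NOTE(review): on >=4.56.0 main() now loads with dtype="auto" instead of the bfloat16 it used before. Confirm this is intended.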
19
+ import transformers
20
+ from packaging import version
21
+ if version.parse(transformers.__version__) >= version.parse("4.56.0"):
22
+ model_from_pretrained_kwargs = {"dtype": "auto"}
23
+ else:
24
+ model_from_pretrained_kwargs = {"torch_dtype": torch.bfloat16}
25
+
18
26
  PROMPT = (
19
27
  "Paraphrase the user provided text while preserving semantic similarity. "
20
28
  "Do not include any other sentences in the response, such as explanations of the paraphrasing. "
@@ -306,8 +314,8 @@ def main():
306
314
  tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
307
315
  model = AutoModelForCausalLM.from_pretrained(
308
316
  model_name_or_path,
309
- torch_dtype=torch.bfloat16,
310
317
  device_map=device,
318
+ **model_from_pretrained_kwargs,
311
319
  )
312
320
 
313
321
  watermarker = Watermarker(tokenizer=tokenizer, model=model, id=id, kappa=kappa, k_p=k_p, watermarkingFnClass=watermarkingFnClass)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: waterfall
3
- Version: 0.2.11
3
+ Version: 0.2.12
4
4
  Summary: Scalable Framework for Robust Text Watermarking and Provenance for LLMs
5
5
  Project-URL: Homepage, https://github.com/aoi3142/Waterfall
6
6
  Project-URL: Issues, https://github.com/aoi3142/Waterfall/issues
@@ -15,7 +15,7 @@ Requires-Dist: numpy>=1.25.0
15
15
  Requires-Dist: scipy>=1.13.0
16
16
  Requires-Dist: sentence-transformers>=3.0.0
17
17
  Requires-Dist: torch>=2.3.0
18
- Requires-Dist: transformers<4.57.0,>=4.43.1
18
+ Requires-Dist: transformers>=4.43.1
19
19
  Description-Content-Type: text/markdown
20
20
 
21
21
  # Waterfall: Scalable Framework for Robust Text Watermarking and Provenance for LLMs [EMNLP 2024 Main Long]
@@ -1,12 +1,12 @@
1
- waterfall/WatermarkerBase.py,sha256=pk2NB7J0oBLXcO0FIRBHllnSowbpeRNd9ZjvPuUOeeM,21945
1
+ waterfall/WatermarkerBase.py,sha256=6O_S78dD3Jha2OkJK2u3euwCH93i-mTiYYGXosPDMig,22632
2
2
  waterfall/WatermarkingFn.py,sha256=-b-kGRdL0a7eKRqJmcHPAR_rCjxQYnsg1Ne6bTwBc1I,1931
3
3
  waterfall/WatermarkingFnFourier.py,sha256=QYayAQYwi1dQkDIyqmvhU568VhrVYTVy47HkI8F8SZs,1358
4
4
  waterfall/WatermarkingFnSquare.py,sha256=2PAO05DdKT02npo7GDf_82D520nP7kGAWK6H4E4JMt4,1638
5
5
  waterfall/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  waterfall/permute.py,sha256=uYKdmn4pGvjB6hInInLGxFIF6vt507lqJ_qe-ST1PFE,2783
7
- waterfall/watermark.py,sha256=avyQIFJBhqu_q_ZBp0-RWvAOIJmzJvVisbiIca2GPyA,14536
8
- waterfall-0.2.11.dist-info/METADATA,sha256=Ttp-F0sjA31gppuF6dfo5ze4HfNWIPYnNAf3VF0h02E,8768
9
- waterfall-0.2.11.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
10
- waterfall-0.2.11.dist-info/entry_points.txt,sha256=XXnUzuWXu2nc9j4WAll9tq6HyodN_8WJLjeG0O4Y2Gw,60
11
- waterfall-0.2.11.dist-info/licenses/LICENSE,sha256=zAtaO-k41Q-Q4Etl4bzuh7pgNJsPH-dYfzvznRa0OvM,11341
12
- waterfall-0.2.11.dist-info/RECORD,,
7
+ waterfall/watermark.py,sha256=fvscFoSbM51YUuDaOmrOKGvwXO25VMgGJKTfAeeKCaA,14817
8
+ waterfall-0.2.12.dist-info/METADATA,sha256=TBoeAFK8qkG-jIRi-OeKq4GbtFTDQkaUKyKuFNeDQHo,8760
9
+ waterfall-0.2.12.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
10
+ waterfall-0.2.12.dist-info/entry_points.txt,sha256=XXnUzuWXu2nc9j4WAll9tq6HyodN_8WJLjeG0O4Y2Gw,60
11
+ waterfall-0.2.12.dist-info/licenses/LICENSE,sha256=zAtaO-k41Q-Q4Etl4bzuh7pgNJsPH-dYfzvznRa0OvM,11341
12
+ waterfall-0.2.12.dist-info/RECORD,,