waterfall 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
waterfall/WatermarkerBase.py CHANGED
@@ -24,6 +24,24 @@ from waterfall.WatermarkingFnFourier import WatermarkingFnFourier
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
+# Check transformers version
+import transformers
+from packaging import version
+
+transformers_version = version.parse(transformers.__version__)
+# Set model loading kwargs based on transformers version
+if transformers_version >= version.parse("4.56.0"):
+    model_from_pretrained_kwargs = {"dtype": "auto"}
+else:
+    model_from_pretrained_kwargs = {"torch_dtype": torch.bfloat16}
+# Group beam search is shifted to transformers-community package in 4.57.0
+use_custom_group_beam_search = transformers_version >= version.parse("4.57.0")
+# BatchEncoding to() non_blocking added in 4.48.0
+if transformers_version >= version.parse("4.48.0"):
+    batch_encoding_to_kwargs = {"non_blocking": True}
+else:
+    batch_encoding_to_kwargs = {}
+
 class PerturbationProcessor(LogitsProcessor):
     def __init__(self,
                  N : int = 32000, # Vocab size
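Note: the block added above keys every compatibility switch on the installed transformers release via packaging.version. A minimal standalone sketch of the same pattern, useful for checking which code paths a given environment would take (variable names mirror the diff; the script itself is not part of the package):

import torch
import transformers
from packaging import version

v = version.parse(transformers.__version__)

# Mirror the three switches introduced in the hunk above.
model_kwargs = {"dtype": "auto"} if v >= version.parse("4.56.0") else {"torch_dtype": torch.bfloat16}
hub_group_beam_search = v >= version.parse("4.57.0")
to_kwargs = {"non_blocking": True} if v >= version.parse("4.48.0") else {}

print(f"transformers {v}")
print(f"  from_pretrained kwargs : {model_kwargs}")
print(f"  Hub group beam search  : {hub_group_beam_search}")
print(f"  BatchEncoding.to kwargs: {to_kwargs}")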
@@ -134,7 +152,7 @@ class Watermarker:
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name_or_path,
             device_map=device_map,
-            torch_dtype=dtype,
+            **model_from_pretrained_kwargs,
         )
 
     def compute_phi(self, watermarkingFnClass = WatermarkingFnFourier) -> None:
@@ -279,6 +297,9 @@ class Watermarker:
         diversity_penalty = kwargs.get("diversity_penalty", None)
         if num_beams <= 1:
             kwargs["diversity_penalty"] = None
+        if use_custom_group_beam_search:
+            kwargs["custom_generate"]="transformers-community/group-beam-search"
+            kwargs["trust_remote_code"]=True
 
         if num_beams > 1 and temperature is not None and temperature != 1.0:
             logits_processor.append(TemperatureLogitsWarper(float(temperature)))
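Note: from transformers 4.57.0, group beam search is no longer implemented inside generate() itself and instead lives in the transformers-community/group-beam-search repository on the Hub, which is what the two kwargs above opt into. A rough, hypothetical illustration of how such kwargs reach a generate() call on transformers >= 4.57.0 (the model name and decoding parameters are placeholders, not taken from the package):

from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")            # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("Paraphrase this sentence.", return_tensors="pt")
out = model.generate(
    **inputs,
    num_beams=4,
    num_beam_groups=2,
    diversity_penalty=0.5,
    max_new_tokens=32,
    # Route decoding through the Hub implementation; this requires a
    # transformers release whose generate() supports custom_generate.
    custom_generate="transformers-community/group-beam-search",
    trust_remote_code=True,
)
print(tok.decode(out[0], skip_special_tokens=True))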
@@ -302,7 +323,7 @@
         tokd_inputs_order = range(len(tokd_inputs))
         tokd_input_batches = []
         for i in range(0, len(tokd_inputs), max_batch_size):
-            batch = self.tokenizer.pad(tokd_inputs[i:i+max_batch_size], padding=True, padding_side="left").to(self.model.device, non_blocking=True)
+            batch = self.tokenizer.pad(tokd_inputs[i:i+max_batch_size], padding=True, padding_side="left").to(self.model.device, **batch_encoding_to_kwargs)
             tokd_input_batches.append(batch)
         torch.cuda.synchronize()
 
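Note: non_blocking=True on BatchEncoding.to() (supported since transformers 4.48.0) lets the host-to-device copies overlap with the padding loop, which is why all batches are queued before the single torch.cuda.synchronize() above. A rough standalone sketch of the same idea with plain tensors (shapes and names are made up for illustration):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_cuda = device.type == "cuda"

# Queue all host-to-device copies without blocking the Python loop...
batches = [torch.randint(0, 32000, (8, 128)) for _ in range(4)]
if use_cuda:
    batches = [b.pin_memory() for b in batches]  # pinned memory enables async copies
device_batches = [b.to(device, non_blocking=use_cuda) for b in batches]

# ...then wait once for every transfer to complete before using the data.
if use_cuda:
    torch.cuda.synchronize()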
waterfall/watermark.py CHANGED
@@ -15,6 +15,21 @@ from waterfall.WatermarkingFnFourier import WatermarkingFnFourier
 from waterfall.WatermarkingFnSquare import WatermarkingFnSquare
 from waterfall.WatermarkerBase import Watermarker
 
+# Check transformers version
+import transformers
+from packaging import version
+transformers_version = version.parse(transformers.__version__)
+if transformers_version >= version.parse("4.56.0"):
+    model_from_pretrained_kwargs = {"dtype": "auto"}
+else:
+    model_from_pretrained_kwargs = {"torch_dtype": torch.bfloat16}
+if transformers_version < version.parse("5.0.0") and transformers_version >= version.parse("4.50.0"):
+    additional_generation_config = {
+        "use_model_defaults": False,
+    }
+else:
+    additional_generation_config = {}
+
 PROMPT = (
     "Paraphrase the user provided text while preserving semantic similarity. "
     "Do not include any other sentences in the response, such as explanations of the paraphrasing. "
@@ -160,7 +175,7 @@ def watermark_texts(
         return_scores=True,
         use_tqdm=use_tqdm,
         generation_config=generation_config,
-        use_model_defaults=False,
+        **additional_generation_config,
     )
     T_ws = watermarked["text"]
     # Reshape T_ws to Queries X Beams
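Note: use_model_defaults only exists as a generate() argument from transformers 4.50.0, so the earlier hunk builds additional_generation_config conditionally and the call above simply splats it, which stays valid on releases without the parameter. A purely illustrative alternative would be to feature-detect the parameter instead of comparing versions; a sketch, assuming the argument is exposed explicitly in the generate() signature as in recent releases (this is not how the package does it):

import inspect
from transformers import GenerationMixin

# Add the optional kwarg only if this transformers build exposes the parameter.
generate_params = inspect.signature(GenerationMixin.generate).parameters
additional_generation_config = (
    {"use_model_defaults": False} if "use_model_defaults" in generate_params else {}
)
print(additional_generation_config)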
@@ -306,8 +321,8 @@ def main():
     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
     model = AutoModelForCausalLM.from_pretrained(
         model_name_or_path,
-        torch_dtype=torch.bfloat16,
         device_map=device,
+        **model_from_pretrained_kwargs,
     )
 
     watermarker = Watermarker(tokenizer=tokenizer, model=model, id=id, kappa=kappa, k_p=k_p, watermarkingFnClass=watermarkingFnClass)
waterfall-0.2.11.dist-info/METADATA → waterfall-0.2.13.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: waterfall
-Version: 0.2.11
+Version: 0.2.13
 Summary: Scalable Framework for Robust Text Watermarking and Provenance for LLMs
 Project-URL: Homepage, https://github.com/aoi3142/Waterfall
 Project-URL: Issues, https://github.com/aoi3142/Waterfall/issues
@@ -15,7 +15,7 @@ Requires-Dist: numpy>=1.25.0
 Requires-Dist: scipy>=1.13.0
 Requires-Dist: sentence-transformers>=3.0.0
 Requires-Dist: torch>=2.3.0
-Requires-Dist: transformers<4.57.0,>=4.43.1
+Requires-Dist: transformers>=4.43.1
 Description-Content-Type: text/markdown
 
 # Waterfall: Scalable Framework for Robust Text Watermarking and Provenance for LLMs [EMNLP 2024 Main Long]
waterfall-0.2.11.dist-info/RECORD → waterfall-0.2.13.dist-info/RECORD CHANGED
@@ -1,12 +1,12 @@
-waterfall/WatermarkerBase.py,sha256=pk2NB7J0oBLXcO0FIRBHllnSowbpeRNd9ZjvPuUOeeM,21945
+waterfall/WatermarkerBase.py,sha256=H-tJ96WUihW30EFFnPn92pna4qQtyYjcWBlVVtY3oMM,22863
 waterfall/WatermarkingFn.py,sha256=-b-kGRdL0a7eKRqJmcHPAR_rCjxQYnsg1Ne6bTwBc1I,1931
 waterfall/WatermarkingFnFourier.py,sha256=QYayAQYwi1dQkDIyqmvhU568VhrVYTVy47HkI8F8SZs,1358
 waterfall/WatermarkingFnSquare.py,sha256=2PAO05DdKT02npo7GDf_82D520nP7kGAWK6H4E4JMt4,1638
 waterfall/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 waterfall/permute.py,sha256=uYKdmn4pGvjB6hInInLGxFIF6vt507lqJ_qe-ST1PFE,2783
-waterfall/watermark.py,sha256=avyQIFJBhqu_q_ZBp0-RWvAOIJmzJvVisbiIca2GPyA,14536
-waterfall-0.2.11.dist-info/METADATA,sha256=Ttp-F0sjA31gppuF6dfo5ze4HfNWIPYnNAf3VF0h02E,8768
-waterfall-0.2.11.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-waterfall-0.2.11.dist-info/entry_points.txt,sha256=XXnUzuWXu2nc9j4WAll9tq6HyodN_8WJLjeG0O4Y2Gw,60
-waterfall-0.2.11.dist-info/licenses/LICENSE,sha256=zAtaO-k41Q-Q4Etl4bzuh7pgNJsPH-dYfzvznRa0OvM,11341
-waterfall-0.2.11.dist-info/RECORD,,
+waterfall/watermark.py,sha256=Qe_NSNH2XL5ZCf069fa438NOpNsju3l4kr2GDoKbuVU,15093
+waterfall-0.2.13.dist-info/METADATA,sha256=VwO9mXTFEOoFxRASPt7qeZVCkMIbhH3_LkJ02yccOFM,8760
+waterfall-0.2.13.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+waterfall-0.2.13.dist-info/entry_points.txt,sha256=XXnUzuWXu2nc9j4WAll9tq6HyodN_8WJLjeG0O4Y2Gw,60
+waterfall-0.2.13.dist-info/licenses/LICENSE,sha256=zAtaO-k41Q-Q4Etl4bzuh7pgNJsPH-dYfzvznRa0OvM,11341
+waterfall-0.2.13.dist-info/RECORD,,