waterfall-0.2.12-py3-none-any.whl → waterfall-0.2.13-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
waterfall/WatermarkerBase.py CHANGED
@@ -27,13 +27,20 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
 # Check transformers version
 import transformers
 from packaging import version
-# Group beam search is shifted to transformers-community package in 4.57.0
-use_custom_group_beam_search = version.parse(transformers.__version__) >= version.parse("4.57.0")
+
+transformers_version = version.parse(transformers.__version__)
 # Set model loading kwargs based on transformers version
-if version.parse(transformers.__version__) >= version.parse("4.56.0"):
+if transformers_version >= version.parse("4.56.0"):
     model_from_pretrained_kwargs = {"dtype": "auto"}
 else:
     model_from_pretrained_kwargs = {"torch_dtype": torch.bfloat16}
+# Group beam search is shifted to transformers-community package in 4.57.0
+use_custom_group_beam_search = transformers_version >= version.parse("4.57.0")
+# BatchEncoding to() non_blocking added in 4.48.0
+if transformers_version >= version.parse("4.48.0"):
+    batch_encoding_to_kwargs = {"non_blocking": True}
+else:
+    batch_encoding_to_kwargs = {}
 
 class PerturbationProcessor(LogitsProcessor):
     def __init__(self,
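
The hunk above folds the repeated version.parse(transformers.__version__) calls into a single module-level transformers_version and gates three things on it: the dtype/torch_dtype argument to from_pretrained (split at 4.56.0), the fallback to a custom group beam search once 4.57.0 moved it to the transformers-community package, and the non_blocking flag for BatchEncoding.to() added in 4.48.0. A minimal sketch of consuming such a gated kwargs dict, not part of the diff; the model id is a placeholder:

import torch
import transformers
from packaging import version
from transformers import AutoModelForCausalLM

transformers_version = version.parse(transformers.__version__)

# transformers >= 4.56.0 prefers `dtype`; older releases still take `torch_dtype`.
if transformers_version >= version.parse("4.56.0"):
    model_from_pretrained_kwargs = {"dtype": "auto"}
else:
    model_from_pretrained_kwargs = {"torch_dtype": torch.bfloat16}

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model id, not from the package
    **model_from_pretrained_kwargs,
)
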
@@ -316,7 +323,7 @@ class Watermarker:
         tokd_inputs_order = range(len(tokd_inputs))
         tokd_input_batches = []
         for i in range(0, len(tokd_inputs), max_batch_size):
-            batch = self.tokenizer.pad(tokd_inputs[i:i+max_batch_size], padding=True, padding_side="left").to(self.model.device, non_blocking=True)
+            batch = self.tokenizer.pad(tokd_inputs[i:i+max_batch_size], padding=True, padding_side="left").to(self.model.device, **batch_encoding_to_kwargs)
             tokd_input_batches.append(batch)
         torch.cuda.synchronize()
 
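
The second hunk consumes batch_encoding_to_kwargs so that non_blocking=True is only forwarded to BatchEncoding.to() on releases that accept it (4.48.0+, per the comment in the first hunk). A standalone sketch of the same pattern; the tokenizer below is a placeholder, not the package's:

import torch
import transformers
from packaging import version
from transformers import AutoTokenizer

transformers_version = version.parse(transformers.__version__)

# Only forward non_blocking where BatchEncoding.to() accepts it (4.48.0+).
if transformers_version >= version.parse("4.48.0"):
    batch_encoding_to_kwargs = {"non_blocking": True}
else:
    batch_encoding_to_kwargs = {}

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder tokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"

batch = tokenizer(["some text", "a longer piece of text"],
                  padding=True, return_tensors="pt")
# Older signatures only take a device, so the extra kwarg has to stay optional.
batch = batch.to(device, **batch_encoding_to_kwargs)
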
waterfall/watermark.py CHANGED
@@ -18,10 +18,17 @@ from waterfall.WatermarkerBase import Watermarker
 # Check transformers version
 import transformers
 from packaging import version
-if version.parse(transformers.__version__) >= version.parse("4.56.0"):
+transformers_version = version.parse(transformers.__version__)
+if transformers_version >= version.parse("4.56.0"):
     model_from_pretrained_kwargs = {"dtype": "auto"}
 else:
     model_from_pretrained_kwargs = {"torch_dtype": torch.bfloat16}
+if transformers_version < version.parse("5.0.0") and transformers_version >= version.parse("4.50.0"):
+    additional_generation_config = {
+        "use_model_defaults": False,
+    }
+else:
+    additional_generation_config = {}
 
 PROMPT = (
     "Paraphrase the user provided text while preserving semantic similarity. "
@@ -168,7 +175,7 @@ def watermark_texts(
         return_scores=True,
         use_tqdm=use_tqdm,
         generation_config=generation_config,
-        use_model_defaults=False,
+        **additional_generation_config,
     )
     T_ws = watermarked["text"]
     # Reshape T_ws to Queries X Beams
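
Here the use_model_defaults=False argument, which the package now only forwards on transformers 4.50.0 up to (but not including) 5.0.0, moves behind the gated additional_generation_config dict built in the previous hunk and is spread into the generation call. A sketch of forwarding such a dict to generate(); the model id and generation settings are placeholders, not the package's own:

import transformers
from packaging import version
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

transformers_version = version.parse(transformers.__version__)

# Per the gate above, only forward use_model_defaults on 4.50.0 <= v < 5.0.0.
if version.parse("4.50.0") <= transformers_version < version.parse("5.0.0"):
    additional_generation_config = {"use_model_defaults": False}
else:
    additional_generation_config = {}

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder model id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Paraphrase the user provided text: the cat sat on the mat.",
                   return_tensors="pt")
generation_config = GenerationConfig(max_new_tokens=32, num_beams=4,
                                     num_return_sequences=4, do_sample=False)

out = model.generate(**inputs,
                     generation_config=generation_config,
                     **additional_generation_config)
print(tokenizer.batch_decode(out, skip_special_tokens=True))
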
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: waterfall
-Version: 0.2.12
+Version: 0.2.13
 Summary: Scalable Framework for Robust Text Watermarking and Provenance for LLMs
 Project-URL: Homepage, https://github.com/aoi3142/Waterfall
 Project-URL: Issues, https://github.com/aoi3142/Waterfall/issues
@@ -1,12 +1,12 @@
-waterfall/WatermarkerBase.py,sha256=6O_S78dD3Jha2OkJK2u3euwCH93i-mTiYYGXosPDMig,22632
+waterfall/WatermarkerBase.py,sha256=H-tJ96WUihW30EFFnPn92pna4qQtyYjcWBlVVtY3oMM,22863
 waterfall/WatermarkingFn.py,sha256=-b-kGRdL0a7eKRqJmcHPAR_rCjxQYnsg1Ne6bTwBc1I,1931
 waterfall/WatermarkingFnFourier.py,sha256=QYayAQYwi1dQkDIyqmvhU568VhrVYTVy47HkI8F8SZs,1358
 waterfall/WatermarkingFnSquare.py,sha256=2PAO05DdKT02npo7GDf_82D520nP7kGAWK6H4E4JMt4,1638
 waterfall/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 waterfall/permute.py,sha256=uYKdmn4pGvjB6hInInLGxFIF6vt507lqJ_qe-ST1PFE,2783
-waterfall/watermark.py,sha256=fvscFoSbM51YUuDaOmrOKGvwXO25VMgGJKTfAeeKCaA,14817
-waterfall-0.2.12.dist-info/METADATA,sha256=TBoeAFK8qkG-jIRi-OeKq4GbtFTDQkaUKyKuFNeDQHo,8760
-waterfall-0.2.12.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-waterfall-0.2.12.dist-info/entry_points.txt,sha256=XXnUzuWXu2nc9j4WAll9tq6HyodN_8WJLjeG0O4Y2Gw,60
-waterfall-0.2.12.dist-info/licenses/LICENSE,sha256=zAtaO-k41Q-Q4Etl4bzuh7pgNJsPH-dYfzvznRa0OvM,11341
-waterfall-0.2.12.dist-info/RECORD,,
+waterfall/watermark.py,sha256=Qe_NSNH2XL5ZCf069fa438NOpNsju3l4kr2GDoKbuVU,15093
+waterfall-0.2.13.dist-info/METADATA,sha256=VwO9mXTFEOoFxRASPt7qeZVCkMIbhH3_LkJ02yccOFM,8760
+waterfall-0.2.13.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+waterfall-0.2.13.dist-info/entry_points.txt,sha256=XXnUzuWXu2nc9j4WAll9tq6HyodN_8WJLjeG0O4Y2Gw,60
+waterfall-0.2.13.dist-info/licenses/LICENSE,sha256=zAtaO-k41Q-Q4Etl4bzuh7pgNJsPH-dYfzvznRa0OvM,11341
+waterfall-0.2.13.dist-info/RECORD,,