waterfall 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -228,6 +228,7 @@ class Watermarker:
228
228
  use_tqdm : bool = False,
229
229
  batched_generate : bool = True,
230
230
  discard_incomplete : bool = True,
231
+ logits_processor = [],
231
232
  **kwargs # Other generate parameters
232
233
  ) -> List[str] | dict: # Returns flattened list of query x beam
233
234
 
@@ -259,7 +260,6 @@ class Watermarker:
259
260
  squeezed_tokd_inputs.append(BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask}))
260
261
  tokd_inputs = squeezed_tokd_inputs
261
262
 
262
- logits_processor = []
263
263
  # Ensure top_k and top_p happens before watermarking
264
264
  if "generation_config" in kwargs:
265
265
  generation_config: GenerationConfig = kwargs["generation_config"]
@@ -275,12 +275,12 @@ class Watermarker:
275
275
  top_k = kwargs.pop("top_k", None)
276
276
  top_p = kwargs.pop("top_p", None)
277
277
  temperature = kwargs.pop("temperature", 1.0)
278
- num_beams = kwargs.pop("num_beams", 1)
279
- diversity_penalty = kwargs.pop("diversity_penalty", None)
278
+ num_beams = kwargs.get("num_beams", 1)
279
+ diversity_penalty = kwargs.get("diversity_penalty", None)
280
280
  if num_beams <= 1:
281
281
  kwargs["diversity_penalty"] = None
282
282
 
283
- if num_beams > 1 and temperature != 1.0:
283
+ if num_beams > 1 and temperature is not None and temperature != 1.0:
284
284
  logits_processor.append(TemperatureLogitsWarper(float(temperature)))
285
285
  if top_k is not None and top_k != 0:
286
286
  logits_processor.append(TopKLogitsWarper(top_k))
@@ -351,10 +351,6 @@ class Watermarker:
351
351
  decoded_output = [i.strip() for i in decoded_output]
352
352
  return_dict["text"] = decoded_output
353
353
 
354
- if is_single:
355
- for k, v in return_dict.items():
356
- return_dict[k] = v[0]
357
-
358
354
  if return_text and len(return_dict) == 1:
359
355
  return decoded_output
360
356
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: waterfall
3
- Version: 0.2.9
3
+ Version: 0.2.11
4
4
  Summary: Scalable Framework for Robust Text Watermarking and Provenance for LLMs
5
5
  Project-URL: Homepage, https://github.com/aoi3142/Waterfall
6
6
  Project-URL: Issues, https://github.com/aoi3142/Waterfall/issues
@@ -15,7 +15,7 @@ Requires-Dist: numpy>=1.25.0
15
15
  Requires-Dist: scipy>=1.13.0
16
16
  Requires-Dist: sentence-transformers>=3.0.0
17
17
  Requires-Dist: torch>=2.3.0
18
- Requires-Dist: transformers<4.55.0,>=4.43.1
18
+ Requires-Dist: transformers<4.57.0,>=4.43.1
19
19
  Description-Content-Type: text/markdown
20
20
 
21
21
  # Waterfall: Scalable Framework for Robust Text Watermarking and Provenance for LLMs [EMNLP 2024 Main Long]
@@ -0,0 +1,12 @@
1
+ waterfall/WatermarkerBase.py,sha256=pk2NB7J0oBLXcO0FIRBHllnSowbpeRNd9ZjvPuUOeeM,21945
2
+ waterfall/WatermarkingFn.py,sha256=-b-kGRdL0a7eKRqJmcHPAR_rCjxQYnsg1Ne6bTwBc1I,1931
3
+ waterfall/WatermarkingFnFourier.py,sha256=QYayAQYwi1dQkDIyqmvhU568VhrVYTVy47HkI8F8SZs,1358
4
+ waterfall/WatermarkingFnSquare.py,sha256=2PAO05DdKT02npo7GDf_82D520nP7kGAWK6H4E4JMt4,1638
5
+ waterfall/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
+ waterfall/permute.py,sha256=uYKdmn4pGvjB6hInInLGxFIF6vt507lqJ_qe-ST1PFE,2783
7
+ waterfall/watermark.py,sha256=avyQIFJBhqu_q_ZBp0-RWvAOIJmzJvVisbiIca2GPyA,14536
8
+ waterfall-0.2.11.dist-info/METADATA,sha256=Ttp-F0sjA31gppuF6dfo5ze4HfNWIPYnNAf3VF0h02E,8768
9
+ waterfall-0.2.11.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
10
+ waterfall-0.2.11.dist-info/entry_points.txt,sha256=XXnUzuWXu2nc9j4WAll9tq6HyodN_8WJLjeG0O4Y2Gw,60
11
+ waterfall-0.2.11.dist-info/licenses/LICENSE,sha256=zAtaO-k41Q-Q4Etl4bzuh7pgNJsPH-dYfzvznRa0OvM,11341
12
+ waterfall-0.2.11.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: hatchling 1.27.0
2
+ Generator: hatchling 1.28.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
@@ -1,12 +0,0 @@
1
- waterfall/WatermarkerBase.py,sha256=ZyH1AWlXl1ak9lCnhtBdsd4IXsHKFGNnx_yYJFnf_zw,22018
2
- waterfall/WatermarkingFn.py,sha256=-b-kGRdL0a7eKRqJmcHPAR_rCjxQYnsg1Ne6bTwBc1I,1931
3
- waterfall/WatermarkingFnFourier.py,sha256=QYayAQYwi1dQkDIyqmvhU568VhrVYTVy47HkI8F8SZs,1358
4
- waterfall/WatermarkingFnSquare.py,sha256=2PAO05DdKT02npo7GDf_82D520nP7kGAWK6H4E4JMt4,1638
5
- waterfall/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
- waterfall/permute.py,sha256=uYKdmn4pGvjB6hInInLGxFIF6vt507lqJ_qe-ST1PFE,2783
7
- waterfall/watermark.py,sha256=avyQIFJBhqu_q_ZBp0-RWvAOIJmzJvVisbiIca2GPyA,14536
8
- waterfall-0.2.9.dist-info/METADATA,sha256=9sFs2kA_fj5CUD2UBt9uEuB1UTgn-ITJWUFWkksEqzE,8767
9
- waterfall-0.2.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
10
- waterfall-0.2.9.dist-info/entry_points.txt,sha256=XXnUzuWXu2nc9j4WAll9tq6HyodN_8WJLjeG0O4Y2Gw,60
11
- waterfall-0.2.9.dist-info/licenses/LICENSE,sha256=zAtaO-k41Q-Q4Etl4bzuh7pgNJsPH-dYfzvznRa0OvM,11341
12
- waterfall-0.2.9.dist-info/RECORD,,