waterfall 0.2.8__py3-none-any.whl → 0.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,7 +15,7 @@ from tqdm import tqdm
15
15
  from transformers import AutoTokenizer, AutoModelForCausalLM
16
16
  from transformers.modeling_utils import PreTrainedModel
17
17
  from transformers.tokenization_utils_base import PreTrainedTokenizerBase, BatchEncoding
18
- from transformers.generation.logits_process import LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper
18
+ from transformers.generation.logits_process import LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper
19
19
  from transformers.generation.configuration_utils import GenerationConfig
20
20
 
21
21
  from waterfall.permute import Permute
@@ -265,11 +265,23 @@ class Watermarker:
265
265
  generation_config: GenerationConfig = kwargs["generation_config"]
266
266
  top_k = generation_config.top_k
267
267
  top_p = generation_config.top_p
268
- generation_config.update(top_p=1.0)
268
+ temperature = generation_config.temperature
269
+ num_beams = generation_config.num_beams
270
+ diversity_penalty = generation_config.diversity_penalty
271
+ if num_beams <= 1:
272
+ diversity_penalty = None
273
+ generation_config.update(top_p=1.0, temperature=None, diversity_penalty=diversity_penalty)
269
274
  else:
270
275
  top_k = kwargs.pop("top_k", None)
271
276
  top_p = kwargs.pop("top_p", None)
272
-
277
+ temperature = kwargs.pop("temperature", 1.0)
278
+ num_beams = kwargs.pop("num_beams", 1)
279
+ diversity_penalty = kwargs.pop("diversity_penalty", None)
280
+ if num_beams <= 1:
281
+ kwargs["diversity_penalty"] = None
282
+
283
+ if num_beams > 1 and temperature != 1.0:
284
+ logits_processor.append(TemperatureLogitsWarper(float(temperature)))
273
285
  if top_k is not None and top_k != 0:
274
286
  logits_processor.append(TopKLogitsWarper(top_k))
275
287
  if top_p is not None and top_p < 1.0:
waterfall/watermark.py CHANGED
@@ -39,6 +39,17 @@ def detect_gpu() -> str:
39
39
  else:
40
40
  return 'cpu'
41
41
 
42
+ def del_cached_model():
43
+ global waterfall_cached_watermarking_model
44
+ if isinstance(waterfall_cached_watermarking_model, PreTrainedModel):
45
+ device = waterfall_cached_watermarking_model.device.type
46
+ waterfall_cached_watermarking_model = None
47
+ gc.collect()
48
+ if device == "cuda":
49
+ torch.cuda.empty_cache()
50
+ elif device == "mps":
51
+ torch.mps.empty_cache()
52
+
42
53
  def watermark_texts(
43
54
  T_os: List[str],
44
55
  id: Optional[int] = None,
@@ -60,6 +71,7 @@ def watermark_texts(
60
71
  beams_per_group: int = 2,
61
72
  diversity_penalty: float = 0.5,
62
73
  stop_at_double_newline: bool = True, # if True, will stop generation at the first double newline. Prevent repeated paraphrasing of the same text.
74
+ **kwargs,
63
75
  ) -> List[str]:
64
76
  if watermark_fn == 'fourier':
65
77
  watermarkingFnClass = WatermarkingFnFourier
@@ -72,16 +84,8 @@ def watermark_texts(
72
84
  if watermarker is None:
73
85
  assert model_path is not None, "model_path must be provided if watermarker is not passed"
74
86
  assert id is not None, "id must be provided if watermarker is not passed"
75
- global waterfall_cached_watermarking_model
76
-
77
87
  if isinstance(waterfall_cached_watermarking_model, PreTrainedModel) and waterfall_cached_watermarking_model.name_or_path != model_path:
78
- device = waterfall_cached_watermarking_model.device.type
79
- waterfall_cached_watermarking_model = None
80
- gc.collect()
81
- if device == "cuda":
82
- torch.cuda.empty_cache()
83
- elif device == "mps":
84
- torch.mps.empty_cache()
88
+ del_cached_model()
85
89
 
86
90
  if waterfall_cached_watermarking_model is None:
87
91
  model = model_path
@@ -91,7 +95,10 @@ def watermark_texts(
91
95
  watermarker = Watermarker(model=model, id=id, kappa=kappa, k_p=k_p, watermarkingFnClass=watermarkingFnClass)
92
96
  else:
93
97
  device = watermarker.model.device.type
94
- id = watermarker.id
98
+ if id is not None:
99
+ watermarker.set_id(id)
100
+ else:
101
+ id = watermarker.id
95
102
  waterfall_cached_watermarking_model = watermarker.model
96
103
 
97
104
  # Check if sts model is loaded
@@ -122,8 +129,8 @@ def watermark_texts(
122
129
  assert (do_sample and temperature is not None and top_p is not None and num_beam_groups == 1 and beams_per_group == 1), \
123
130
  "do_sample=True requires temperature, top_p, num_beam_groups=1 and beams_per_group=1"
124
131
  else: # Using beam search
125
- assert (not do_sample and temperature is None and top_p is None and num_beam_groups >= 1 and beams_per_group >= 1), \
126
- "do_sample=False requires temperature=None, top_p=None, num_beam_groups>=1 and beams_per_group>=1"
132
+ assert (not do_sample and num_beam_groups >= 1 and beams_per_group >= 1), \
133
+ "do_sample=False requires num_beam_groups>=1 and beams_per_group>=1"
127
134
 
128
135
  eos_token_id = watermarker.tokenizer.eos_token_id
129
136
  # add "\n\n" tokens to eos_token_id list
@@ -144,6 +151,7 @@ def watermark_texts(
144
151
  diversity_penalty=diversity_penalty,
145
152
  eos_token_id=eos_token_id,
146
153
  num_return_sequences=num_beam_groups * beams_per_group,
154
+ **kwargs
147
155
  )
148
156
 
149
157
  watermarked = watermarker.generate(
@@ -152,6 +160,7 @@ def watermark_texts(
152
160
  return_scores=True,
153
161
  use_tqdm=use_tqdm,
154
162
  generation_config=generation_config,
163
+ use_model_defaults=False,
155
164
  )
156
165
  T_ws = watermarked["text"]
157
166
  # Reshape T_ws to Queries X Beams
@@ -179,10 +188,7 @@ def verify_texts(texts: List[str], id: int,
179
188
  assert model_path is not None, "model_path must be provided if watermarker is not passed"
180
189
  watermarker = Watermarker(tokenizer=model_path)
181
190
 
182
- if k_p is None:
183
- k_p = watermarker.k_p
184
-
185
- verify_results = watermarker.verify(texts, id=[id], k_p=[k_p], return_extracted_k_p=return_extracted_k_p) # results are [text x id x k_p]
191
+ verify_results = watermarker.verify(texts, id=[id], k_p=k_p, return_extracted_k_p=return_extracted_k_p) # results are [text x id x k_p]
186
192
 
187
193
  if not return_extracted_k_p:
188
194
  return verify_results[:,0,0]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: waterfall
3
- Version: 0.2.8
3
+ Version: 0.2.9
4
4
  Summary: Scalable Framework for Robust Text Watermarking and Provenance for LLMs
5
5
  Project-URL: Homepage, https://github.com/aoi3142/Waterfall
6
6
  Project-URL: Issues, https://github.com/aoi3142/Waterfall/issues
@@ -76,7 +76,7 @@ Protecting intellectual property (IP) of text such as articles and code is incre
76
76
 
77
77
  # Using our code
78
78
 
79
- Install our package using `pip`
79
+ Install our [PyPI package](https://pypi.org/project/waterfall/) using `pip`
80
80
  ```sh
81
81
  pip install waterfall
82
82
  ```
@@ -1,12 +1,12 @@
1
- waterfall/WatermarkerBase.py,sha256=y7rJtP4Qf8GdJNdlZriX-JLClKxJRGiCtIkBtDoIqZw,21300
1
+ waterfall/WatermarkerBase.py,sha256=ZyH1AWlXl1ak9lCnhtBdsd4IXsHKFGNnx_yYJFnf_zw,22018
2
2
  waterfall/WatermarkingFn.py,sha256=-b-kGRdL0a7eKRqJmcHPAR_rCjxQYnsg1Ne6bTwBc1I,1931
3
3
  waterfall/WatermarkingFnFourier.py,sha256=QYayAQYwi1dQkDIyqmvhU568VhrVYTVy47HkI8F8SZs,1358
4
4
  waterfall/WatermarkingFnSquare.py,sha256=2PAO05DdKT02npo7GDf_82D520nP7kGAWK6H4E4JMt4,1638
5
5
  waterfall/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  waterfall/permute.py,sha256=uYKdmn4pGvjB6hInInLGxFIF6vt507lqJ_qe-ST1PFE,2783
7
- waterfall/watermark.py,sha256=IbH5r3oqjtKztDVryfDTr_NDn-CLZHow0S8nAEtZmdc,14420
8
- waterfall-0.2.8.dist-info/METADATA,sha256=QO_d3epwMllDm1vwvcOwwKL0ho5T6X4NEXXd7YHfiso,8723
9
- waterfall-0.2.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
10
- waterfall-0.2.8.dist-info/entry_points.txt,sha256=XXnUzuWXu2nc9j4WAll9tq6HyodN_8WJLjeG0O4Y2Gw,60
11
- waterfall-0.2.8.dist-info/licenses/LICENSE,sha256=zAtaO-k41Q-Q4Etl4bzuh7pgNJsPH-dYfzvznRa0OvM,11341
12
- waterfall-0.2.8.dist-info/RECORD,,
7
+ waterfall/watermark.py,sha256=avyQIFJBhqu_q_ZBp0-RWvAOIJmzJvVisbiIca2GPyA,14536
8
+ waterfall-0.2.9.dist-info/METADATA,sha256=9sFs2kA_fj5CUD2UBt9uEuB1UTgn-ITJWUFWkksEqzE,8767
9
+ waterfall-0.2.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
10
+ waterfall-0.2.9.dist-info/entry_points.txt,sha256=XXnUzuWXu2nc9j4WAll9tq6HyodN_8WJLjeG0O4Y2Gw,60
11
+ waterfall-0.2.9.dist-info/licenses/LICENSE,sha256=zAtaO-k41Q-Q4Etl4bzuh7pgNJsPH-dYfzvznRa0OvM,11341
12
+ waterfall-0.2.9.dist-info/RECORD,,