waterfall 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -92,10 +92,9 @@ class Watermarker:
92
92
  n_gram : int = 2,
93
93
  watermarkingFnClass = WatermarkingFnFourier,
94
94
  device = None,
95
+ dtype = torch.bfloat16,
95
96
  ) -> None:
96
97
  assert kappa >= 0, f"kappa must be >= 0, value provided is {kappa}"
97
-
98
- self.id = id
99
98
  self.k_p = k_p
100
99
  self.n_gram = n_gram
101
100
  self.kappa = kappa
@@ -116,21 +115,26 @@ class Watermarker:
116
115
 
117
116
  self.N = self.tokenizer.vocab_size
118
117
 
119
- self.logits_processor = PerturbationProcessor(N = self.N, id = self.id)
120
-
121
118
  if isinstance(model, str):
122
- self.load_model(model, device_map=device)
119
+ self.load_model(model, device_map=device, dtype=dtype)
123
120
  else:
124
121
  self.model = model
125
122
 
126
123
  assert (self.model is None) or isinstance(self.model, PreTrainedModel), f"model must be a transformers model, value provided is {type(self.model)}" # argument order for tokenizer and model were swapped since the original code
127
124
 
128
- self.compute_phi(watermarkingFnClass)
125
+ self.watermarkingFnClass = watermarkingFnClass
126
+ self.set_id(id)
127
+
128
+ def set_id(self, id : int):
129
+ self.id = id
130
+ self.logits_processor = PerturbationProcessor(N = self.N, id = self.id)
131
+ self.compute_phi(self.watermarkingFnClass)
129
132
 
130
- def load_model(self, model_name_or_path : str, device_map : str = "auto"):
133
+ def load_model(self, model_name_or_path : str, device_map : str = "auto", dtype = torch.bfloat16):
131
134
  self.model = AutoModelForCausalLM.from_pretrained(
132
135
  model_name_or_path,
133
136
  device_map=device_map,
137
+ torch_dtype=dtype,
134
138
  )
135
139
 
136
140
  def compute_phi(self, watermarkingFnClass = WatermarkingFnFourier) -> None:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: waterfall
3
- Version: 0.2.1
3
+ Version: 0.2.3
4
4
  Summary: Scalable Framework for Robust Text Watermarking and Provenance for LLMs
5
5
  Project-URL: Homepage, https://github.com/aoi3142/Waterfall
6
6
  Project-URL: Issues, https://github.com/aoi3142/Waterfall/issues
@@ -1,12 +1,12 @@
1
- waterfall/WatermarkerBase.py,sha256=A2VRfsnBfz6-8DSL2NKQZdM1OLI0sQ73qjYaV6rIgJ0,20822
1
+ waterfall/WatermarkerBase.py,sha256=nx2HhNCNy3yAhttq77cFnMmxlGucV_T8iVfbviAmClI,21046
2
2
  waterfall/WatermarkingFn.py,sha256=-b-kGRdL0a7eKRqJmcHPAR_rCjxQYnsg1Ne6bTwBc1I,1931
3
3
  waterfall/WatermarkingFnFourier.py,sha256=QYayAQYwi1dQkDIyqmvhU568VhrVYTVy47HkI8F8SZs,1358
4
4
  waterfall/WatermarkingFnSquare.py,sha256=2PAO05DdKT02npo7GDf_82D520nP7kGAWK6H4E4JMt4,1638
5
5
  waterfall/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  waterfall/permute.py,sha256=uYKdmn4pGvjB6hInInLGxFIF6vt507lqJ_qe-ST1PFE,2783
7
7
  waterfall/watermark.py,sha256=IbH5r3oqjtKztDVryfDTr_NDn-CLZHow0S8nAEtZmdc,14420
8
- waterfall-0.2.1.dist-info/METADATA,sha256=Mzyp7Nw395RLCN3wnzp2StEpKZEN2erb5BvCOd5Z-4I,8722
9
- waterfall-0.2.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
10
- waterfall-0.2.1.dist-info/entry_points.txt,sha256=XXnUzuWXu2nc9j4WAll9tq6HyodN_8WJLjeG0O4Y2Gw,60
11
- waterfall-0.2.1.dist-info/licenses/LICENSE,sha256=zAtaO-k41Q-Q4Etl4bzuh7pgNJsPH-dYfzvznRa0OvM,11341
12
- waterfall-0.2.1.dist-info/RECORD,,
8
+ waterfall-0.2.3.dist-info/METADATA,sha256=8FvklSdyfGq495wvxHZeTT-8nb2xF0tzsH64_tbv6xc,8722
9
+ waterfall-0.2.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
10
+ waterfall-0.2.3.dist-info/entry_points.txt,sha256=XXnUzuWXu2nc9j4WAll9tq6HyodN_8WJLjeG0O4Y2Gw,60
11
+ waterfall-0.2.3.dist-info/licenses/LICENSE,sha256=zAtaO-k41Q-Q4Etl4bzuh7pgNJsPH-dYfzvznRa0OvM,11341
12
+ waterfall-0.2.3.dist-info/RECORD,,