python-fedci 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fedci/testing.py CHANGED
@@ -38,7 +38,6 @@ class RegressionTest:
38
38
  self.llf: float = -float("inf")
39
39
  self.iterations: int = 0
40
40
  self.convergence_retry_count = 0
41
- self.reinit_beta = False
42
41
  self.lm_lambda = 1
43
42
 
44
43
  def __repr__(self):
@@ -128,14 +127,13 @@ class RegressionTest:
128
127
  xwz = sum([_update.xwz for _update in update])
129
128
  n = int(np.sum([_update.n for _update in update]).item())
130
129
 
131
- if abs(llf) < 1e-8 and np.allclose(xwx, np.zeros_like(xwx)) and np.allclose(xwz, np.zeros_like(xwz)):
130
+ if abs(llf) < 1e-10:
132
131
  self.early_stop = True
133
132
  return
134
133
 
135
- if not self.reinit_beta and np.allclose(xwz, np.zeros_like(xwz)):
136
- self.reinit_beta = True
137
- # readjust beta -> mostly an issue with small datasets and perfectly even distribution of categories
134
+ if self.iterations == 0 and np.allclose(xwz, np.zeros_like(xwz)): # readjust beta -> mostly an issue with small datasets and perfectly even distribution of categories
138
135
  self.beta = np.random.randn(self.dof, 1)
136
+ self.iterations += 1
139
137
  return
140
138
 
141
139
  if self.response_type == VariableType.CONTINUOS and n > 0:
@@ -187,9 +185,10 @@ class RegressionTest:
187
185
  self.lm_lambda /= 10
188
186
  beta = self._get_new_beta(xwx, xwz)
189
187
  if (
190
- self.iterations == 0
191
- and np.linalg.norm(self.beta - beta) < 1e-4
192
- or np.linalg.norm(self.beta - beta) < 1e-8
188
+ #self.iterations == 0
189
+ #and np.linalg.norm(self.beta - beta) < 1e-4
190
+ #or np.linalg.norm(self.beta - beta) < 1e-8
191
+ np.linalg.norm(self.beta - beta) < 1e-8
193
192
  ):
194
193
  self.early_stop = True
195
194
  return