reflectorch 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.


@@ -29,6 +29,7 @@ from reflectorch.data_generation.q_generator import (
     ConstantQ,
     VariableQ,
     EquidistantQ,
+    MaskedVariableQ,
 )
 from reflectorch.data_generation.noise import (
     QNoiseGenerator,
@@ -78,6 +79,7 @@ __all__ = [
     "ConstantQ",
     "VariableQ",
     "EquidistantQ",
+    "MaskedVariableQ",
     "QNoiseGenerator",
     "IntensityNoiseGenerator",
     "MultiplicativeLogNormalNoiseGenerator",
@@ -37,7 +37,7 @@ class ParametricModel(object):
         self.max_num_layers = max_num_layers
         self._sampler_strategy = self._init_sampler_strategy(**kwargs)
 
-    def _init_sampler_strategy(self, **kwargs):
+    def _init_sampler_strategy(self, nuisance_params_dim: int = 0, **kwargs):
         return BasicSamplerStrategy(**kwargs)
 
     @property
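
The base implementation now accepts a nuisance_params_dim keyword (default 0) and absorbs it, so callers can always pass it without BasicSamplerStrategy receiving an unexpected argument. A hypothetical sketch of a subclass that actually consumes the keyword (the subclass name and stored attribute are illustrative, not from the diff):

    # Hypothetical subclass: picks up the new keyword before building the
    # default strategy; the base class would otherwise just discard it.
    class MyParametricModel(ParametricModel):
        def _init_sampler_strategy(self, nuisance_params_dim: int = 0, **kwargs):
            self.nuisance_params_dim = nuisance_params_dim  # e.g. extra non-structural parameter dims
            return BasicSamplerStrategy(**kwargs)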
@@ -16,6 +16,7 @@ __all__ = [
     "VariableQ",
     "EquidistantQ",
     "ConstantAngle",
+    "MaskedVariableQ",
 ]
 
 
@@ -124,6 +125,11 @@ class VariableQ(QGenerator):
             q = torch.linspace(0, 1, n_q, device=self.device, dtype=self.dtype)
         elif self.mode == 'random':
             q = torch.rand(n_q, device=self.device, dtype=self.dtype).sort().values
+        elif self.mode == 'logspace':
+            q = torch.logspace(
+                start=torch.log10(torch.tensor(1e-4, dtype=self.dtype, device=self.device)),
+                end=torch.log10(torch.tensor(1.0, dtype=self.dtype, device=self.device)),
+                steps=n_q, dtype=self.dtype, device=self.device)
 
         q = q_min[:, None] + q * (q_max - q_min)[:, None]
 
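
The new 'logspace' mode builds log-spaced normalized positions between 1e-4 and 1, which the existing affine rescaling then maps into [q_min, q_max]; because the positions start at 1e-4 rather than 0, the first point lands marginally above q_min, and the grid is densest near the low-q end. A standalone sketch of the spacing logic (the scalar q_min, q_max, and n_q values here are illustrative; VariableQ samples them per batch):

    import torch

    n_q, q_min, q_max = 128, 0.02, 0.3
    # log-spaced normalized positions in [1e-4, 1], densest near the low end
    positions = torch.logspace(-4, 0, steps=n_q, dtype=torch.float64)
    q = q_min + positions * (q_max - q_min)  # q[0] just above q_min, spacing grows toward q_max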
@@ -198,49 +204,77 @@ class EquidistantQ(QGenerator):
         return qs
 
 
-class TransformerQ(QGenerator):
+class MaskedVariableQ:
     def __init__(self,
-                 q_max: float = 0.2,
-                 num_values: Union[int, Tuple[int, int]] = (30, 512),
-                 min_dq_ratio: float = 5.,
-                 device=None,
-                 dtype=torch.float64,
-                 ):
-        self.min_dq_ratio = min_dq_ratio
-        self.q_max = q_max
-        self._dq_range = q_max / num_values[1], q_max / num_values[0]
-        self._num_values = num_values
+                 q_min_range=(0.01, 0.03),
+                 q_max_range=(0.1, 0.5),
+                 n_q_range=(64, 256),
+                 mode='equidistant',
+                 shuffle_mask=False,
+                 total_thickness_constraint=True,
+                 min_points_per_fringe=4,
+                 device=DEFAULT_DEVICE,
+                 dtype=DEFAULT_DTYPE):
+        self.q_min_range = q_min_range
+        self.q_max_range = q_max_range
+        self.n_q_range = n_q_range
         self.device = device
         self.dtype = dtype
-
-    def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
+        self.mode = mode
+        self.shuffle_mask = shuffle_mask
+        self.total_thickness_constraint = total_thickness_constraint
+        self.min_points_per_fringe = min_points_per_fringe
+
+    def get_batch(self, batch_size, context):
         assert context is not None
 
-        params: BasicParams = context['params']
-        total_thickness = params.thicknesses.sum(-1)
-
-        assert total_thickness.shape[0] == batch_size
-
-        min_dqs = torch.clamp(
-            2 * np.pi / total_thickness / self.min_dq_ratio, self._dq_range[0], self._dq_range[1] * 0.9
-        )
-
-        dqs = torch.rand_like(min_dqs) * (self._dq_range[1] - min_dqs) + min_dqs
-
-        num_q_values = torch.clamp(self.q_max // dqs, *self._num_values).to(torch.int)
+        q_min = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_min_range[1] - self.q_min_range[0]) + self.q_min_range[0]
+        q_max = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_max_range[1] - self.q_max_range[0]) + self.q_max_range[0]
 
-        q_values, mask = generate_q_padding_mask(num_q_values, self.q_max)
+        max_n_q = self.n_q_range[1]
 
-        context['tgt_key_padding_mask'] = mask
-        context['num_q_values'] = num_q_values
+        if self.mode == 'equidistant':
+            positions = torch.linspace(0, 1, max_n_q, device=self.device, dtype=self.dtype).expand(batch_size, max_n_q)
+        elif self.mode == 'random':
+            positions = torch.rand(batch_size, max_n_q, device=self.device, dtype=self.dtype)
+            positions, _ = positions.sort(dim=-1)
+        elif self.mode == 'mixed':
+            positions = torch.empty(batch_size, max_n_q, device=self.device, dtype=self.dtype)
+
+            half = batch_size // 2  # first half of the batch gets equidistant positions
+            eq_pos = torch.linspace(0, 1, max_n_q, device=self.device, dtype=self.dtype).expand(half, max_n_q)
+            positions[:half] = eq_pos
+
+            rand_pos = torch.rand(batch_size - half, max_n_q, device=self.device, dtype=self.dtype)  # second half gets sorted random positions
+            rand_pos, _ = rand_pos.sort(dim=-1)
+            positions[half:] = rand_pos
+        else:
+            raise ValueError(f"Unknown spacing mode: {self.mode}")
+
+        q = q_min[:, None] + positions * (q_max - q_min)[:, None]
+
+        n_qs = torch.randint(self.n_q_range[0], self.n_q_range[1] + 1, (batch_size,), device=self.device)
+
+        if 'params' in context and self.total_thickness_constraint:  # n_points > 1 + (q_spread * total_thickness * min_points_per_fringe) / (2*pi)
+            d_total = context['params'].thicknesses.sum(-1)
+            limit = 1 + ((q_max - q_min) * d_total * self.min_points_per_fringe) / (2 * np.pi)
+            limit = limit.ceil().int()
+            n_qs = torch.maximum(n_qs, limit)
+            n_qs = torch.clamp(n_qs, max=self.n_q_range[1])
+
+        indices = torch.arange(max_n_q, device=self.device).expand(batch_size, max_n_q)
+        valid_mask = indices < n_qs[:, None]  # right-side padding: first n_q points are valid
+
+        if self.shuffle_mask:  # shuffle the valid positions (padding interspersed among the points)
+            perm = torch.argsort(torch.rand(batch_size, max_n_q, device=self.device), dim=-1)
+            valid_mask = torch.gather(valid_mask, dim=1, index=perm)
 
-        return q_values
+        context['key_padding_mask'] = valid_mask
+        context['n_points'] = valid_mask.sum(dim=-1)
 
+        return q
+
+    def scale_q(self, q):
+        scaled_q_01 = (q - self.q_min_range[0]) / (self.q_max_range[1] - self.q_min_range[0])
 
-def generate_q_padding_mask(num_q_values: Tensor, q_max: float):
-    batch_size = num_q_values.shape[0]
-    dqs = (q_max / num_q_values)[:, None]
-    q_values = torch.arange(1, num_q_values.max().item() + 1)[None].repeat(batch_size, 1) * dqs
-    mask = (q_values > q_max + dqs / 2)
-    q_values[mask] = 0.
-    return q_values, mask
+        return 2.0 * (scaled_q_01 - 0.5)
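
MaskedVariableQ replaces TransformerQ: every curve in a batch is laid out on a fixed-width grid of n_q_range[1] points, with per-sample q_min, q_max, and point counts, and validity is reported through context['key_padding_mask'] so that a transformer-style network can ignore the padded positions. The total-thickness constraint keeps at least min_points_per_fringe points per Kiessig fringe (fringe period roughly 2*pi/d_total in q), which gives the lower bound n_points > 1 + (q_max - q_min) * d_total * min_points_per_fringe / (2*pi) enforced via torch.maximum, while scale_q affinely maps the overall range [q_min_range[0], q_max_range[1]] onto [-1, 1]. A usage sketch under stated assumptions (the SimpleNamespace stands in for the BasicParams object whose thicknesses attribute get_batch reads, and the top-level import relies on the __all__ change above; all values are illustrative):

    import torch
    from types import SimpleNamespace
    from reflectorch import MaskedVariableQ  # re-exported per the __all__ change above

    # Stand-in for the params object read from the context; only the
    # thicknesses tensor of shape (batch, n_layers) is needed here.
    params = SimpleNamespace(thicknesses=torch.rand(8, 3, dtype=torch.float64) * 300.0)

    q_gen = MaskedVariableQ(q_min_range=(0.01, 0.03), q_max_range=(0.1, 0.5),
                            n_q_range=(64, 256), mode='mixed')
    context = {'params': params}
    q = q_gen.get_batch(batch_size=8, context=context)  # (8, 256) grid of q values
    mask = context['key_padding_mask']                  # (8, 256) bool, True at valid points
    n_pts = context['n_points']                         # (8,) number of valid points per curve
    q_scaled = q_gen.scale_q(q)                         # affine map of [0.01, 0.5] onto [-1, 1]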
@@ -163,6 +163,7 @@ def get_density_profiles(
     else:
         if ambient_sld.ndim == 1:
             ambient_sld = ambient_sld.unsqueeze(-1)
+        ambient_sld = ambient_sld.expand(bs, 1)
 
     slds_all = torch.cat([ambient_sld, slds], dim=-1)  # new dimension: n+2
     d_rhos = torch.diff(slds_all, dim=-1)  # (bs, n+1)
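
The added expand normalizes a shared ambient SLD to the batch size before concatenation: torch.cat requires every non-concatenated dimension to match, so a (1, 1) ambient column next to (bs, n) layer SLDs would otherwise raise a shape error. A minimal reproduction of the shape logic (tensor values are illustrative):

    import torch

    bs = 4
    slds = torch.rand(bs, 3)           # per-sample layer SLDs, shape (bs, n)
    ambient_sld = torch.tensor([0.0])  # one shared ambient value, shape (1,)

    ambient_sld = ambient_sld.unsqueeze(-1)  # (1, 1)
    ambient_sld = ambient_sld.expand(bs, 1)  # (bs, 1); a broadcast view, no copy
    slds_all = torch.cat([ambient_sld, slds], dim=-1)  # (bs, n + 1), now well-formed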