reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic.

Files changed (96)
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -126
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -246
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -222
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -851
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +37 -0
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +524 -98
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -16
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -248
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -191
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -14
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -17
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -428
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -401
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +98 -68
  91. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
  94. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  95. reflectorch-1.3.0.dist-info/RECORD +0 -86
  96. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0

reflectorch/inference/query_matcher.py
@@ -1,82 +1,82 @@
import os
import tempfile
import yaml
from huggingface_hub import hf_hub_download, list_repo_files

class HuggingfaceQueryMatcher:
    """Downloads the available configuration files to a temporary directory and provides functionality for filtering the configuration files that match user-specified queries.

    Args:
        repo_id (str): The Hugging Face repository ID.
        config_dir (str): Directory within the repo where YAML files are stored.
    """
    def __init__(self, repo_id='valentinsingularity/reflectivity', config_dir='configs'):
        self.repo_id = repo_id
        self.config_dir = config_dir
        self.cache = {
            'parsed_configs': None,
            'temp_dir': None
        }
        self._renew_cache()

    def _renew_cache(self):
        temp_dir = tempfile.mkdtemp()
        print(f"Temporary directory created at: {temp_dir}")

        repo_files = list_repo_files(self.repo_id, repo_type='model')
        config_files = [file for file in repo_files if file.startswith(self.config_dir) and file.endswith('.yaml')]

        downloaded_files = []
        for file in config_files:
            file_path = hf_hub_download(repo_id=self.repo_id, filename=file, local_dir=temp_dir, repo_type='model')
            downloaded_files.append(file_path)

        parsed_configs = {}
        for file_path in downloaded_files:
            with open(file_path, 'r') as file:
                config_data = yaml.safe_load(file)
            file_name = os.path.basename(file_path)
            parsed_configs[file_name] = config_data

        self.cache['parsed_configs'] = parsed_configs
        self.cache['temp_dir'] = temp_dir

    def get_matching_configs(self, query):
        """Retrieves the configuration files that match the user-specified query.

        Args:
            query (dict): Dictionary of key-value pairs to filter configurations, e.g. ``query = {'dset.prior_sampler.kwargs.max_num_layers': 3, 'dset.prior_sampler.kwargs.param_ranges.slds': [0., 100.]}``.
                For keys containing the ``param_ranges`` subkey, a configuration is selected if the query value (i.e. the desired parameter range)
                is a subrange of the parameter range in the configuration; in all other cases the values must match exactly.

        Returns:
            list: List of file names that match the query.
        """

        filtered_configs = []

        for file_name, config_data in self.cache['parsed_configs'].items():
            if self.matches_query(config_data, query):
                filtered_configs.append(file_name)

        return filtered_configs

    def matches_query(self, config_data, query):
        for q_key, q_value in query.items():
            keys = q_key.split('.')
            value = self.deep_get(config_data, keys)
            if 'param_ranges' in keys:
                if q_value[0] < value[0] or q_value[1] > value[1]:
                    return False
            else:
                if value != q_value:
                    return False

        return True

    def deep_get(self, d, keys):
        for key in keys:
            if isinstance(d, dict):
                d = d.get(key, None)

        return d
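
For context on the matching logic above, a minimal usage sketch (not part of the package diff). It assumes the class is imported from the module path shown in the file list, reflectorch.inference.query_matcher, and reuses the query keys from the docstring; any other names are illustrative only.

# Usage sketch, assuming the import path reflectorch.inference.query_matcher.
from reflectorch.inference.query_matcher import HuggingfaceQueryMatcher

matcher = HuggingfaceQueryMatcher()  # downloads the repo's YAML configs to a temp dir

# Dotted keys are resolved with deep_get(); keys containing 'param_ranges' use
# subrange matching, all other keys must match exactly (see matches_query above).
query = {
    'dset.prior_sampler.kwargs.max_num_layers': 3,
    'dset.prior_sampler.kwargs.param_ranges.slds': [0., 100.],
}
matching = matcher.get_matching_configs(query)
print(matching)  # list of matching config file names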

reflectorch/inference/record_time.py
@@ -1,43 +1,43 @@
from time import perf_counter
from contextlib import contextmanager
from functools import wraps


class EvaluateTime(list):
    @contextmanager
    def __call__(self, name: str, *args, **kwargs):
        start = perf_counter()
        yield
        self.action(perf_counter() - start, name, *args, **kwargs)

    @staticmethod
    def action(delta_time, name, *args, **kwargs):
        print(f"Time for {name} = {delta_time:.2f} sec")

    def __repr__(self):
        return f'EvaluateTime(total={sum(self)}, num_records={len(self)})'


def print_time(name: str or callable):
    if isinstance(name, str):
        return _print_time_context(name)
    else:
        return _print_time_wrap(name)


def _print_time_wrap(func, name: str = None):
    name = name or func.__name__

    @wraps(func)
    def wrapped_func(*args, **kwargs):
        with _print_time_context(name):
            return func(*args, **kwargs)

    return wrapped_func


@contextmanager
def _print_time_context(name: str):
    start = perf_counter()
    yield
    print(f"Time for {name} = {(perf_counter() - start):.2f} sec")

reflectorch/inference/sampler_solution.py
@@ -1,56 +1,56 @@
import torch
from torch import Tensor

from reflectorch.data_generation.priors.utils import uniform_sampler
from reflectorch.data_generation.priors.subprior_sampler import UniformSubPriorParams
from reflectorch.data_generation.priors.params import Params
from reflectorch.data_generation.likelihoods import LogLikelihood


def simple_sampler_solution(
        likelihood: LogLikelihood,
        predicted_params: UniformSubPriorParams,
        total_min_bounds: Tensor,
        total_max_bounds: Tensor,
        num: int = 2 ** 15,
        coef: float = 0.1,
) -> UniformSubPriorParams:
    sampled_params_t = sample_around_params(predicted_params, total_min_bounds, total_max_bounds, num=num, coef=coef)
    sampled_params = Params.from_tensor(sampled_params_t)
    return get_best_mse_param(sampled_params, likelihood, predicted_params.min_bounds, predicted_params.max_bounds)


def sample_around_params(predicted_params: UniformSubPriorParams,
                         total_min_bounds: Tensor,
                         total_max_bounds: Tensor,
                         num: int = 2 ** 15,
                         coef: float = 0.1,
                         ) -> Tensor:
    params_t = predicted_params.as_tensor(add_bounds=False)

    delta = (predicted_params.max_bounds - predicted_params.min_bounds) * coef
    min_bounds = torch.clamp(params_t - delta, total_min_bounds, total_max_bounds)
    max_bounds = torch.clamp(params_t + delta, total_min_bounds, total_max_bounds)

    sampled_params_t = uniform_sampler(min_bounds, max_bounds, num, params_t.shape[-1])
    sampled_params_t[0] = params_t[0]

    return sampled_params_t


def get_best_mse_param(
        params: Params,
        likelihood: LogLikelihood,
        min_bounds: Tensor = None,
        max_bounds: Tensor = None,
):
    sampled_curves = params.reflectivity(likelihood.q)
    log_probs = likelihood.calc_log_likelihood(sampled_curves)
    best_idx = torch.argmax(log_probs)
    best_param = params[best_idx:best_idx + 1]

    if min_bounds is not None:
        best_param = UniformSubPriorParams.from_tensor(
            torch.cat([best_param.as_tensor(), torch.atleast_2d(min_bounds), torch.atleast_2d(max_bounds)], -1)
        )
    return best_param
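
sampler_solution.py implements a simple refinement step: draw uniform samples in a small box around the predicted parameters (clamped to the global bounds) and keep the candidate whose simulated curve has the highest log-likelihood. A self-contained sketch of that idea in plain PyTorch; the quadratic toy_forward model and Gaussian-style score below are stand-ins for reflectorch's reflectivity simulation and LogLikelihood class, and the box width here is scaled by the global bounds rather than the predicted sub-prior bounds.

import torch

def toy_forward(params: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    # Stand-in forward model (quadratic in q), not the package's reflectivity kernel.
    basis = torch.stack([torch.ones_like(q), q, q ** 2])   # (3, len(q))
    return params @ basis                                   # (batch, len(q))

q = torch.linspace(0.0, 1.0, 64)
data = toy_forward(torch.tensor([[1.0, -2.0, 0.5]]), q)     # "measured" curve

predicted = torch.tensor([[0.9, -1.8, 0.6]])                # e.g. a network prediction
total_min = torch.tensor([-5.0, -5.0, -5.0])
total_max = torch.tensor([5.0, 5.0, 5.0])

# Sample uniformly in a box around the prediction, clamped to the global bounds
# (mirrors sample_around_params; the package scales the box by the sub-prior width).
coef, num = 0.1, 2 ** 12
delta = (total_max - total_min) * coef
lo = torch.clamp(predicted - delta, total_min, total_max)
hi = torch.clamp(predicted + delta, total_min, total_max)
candidates = lo + (hi - lo) * torch.rand(num, predicted.shape[-1])
candidates[0] = predicted[0]                                # keep the prediction itself

# Score every candidate and keep the best one (mirrors get_best_mse_param).
log_probs = -((toy_forward(candidates, q) - data) ** 2).sum(dim=-1)
best = candidates[torch.argmax(log_probs)]
print(best)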