mergeron 2025.739265.0__tar.gz → 2025.739290.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mergeron might be problematic. Click here for more details.

Files changed (23)
  1. {mergeron-2025.739265.0 → mergeron-2025.739290.0}/PKG-INFO +2 -1
  2. {mergeron-2025.739265.0 → mergeron-2025.739290.0}/pyproject.toml +4 -3
  3. {mergeron-2025.739265.0 → mergeron-2025.739290.0}/src/mergeron/__init__.py +51 -2
  4. {mergeron-2025.739265.0 → mergeron-2025.739290.0}/src/mergeron/core/__init__.py +2 -2
  5. {mergeron-2025.739265.0 → mergeron-2025.739290.0}/src/mergeron/core/empirical_margin_distribution.py +2 -2
  6. {mergeron-2025.739265.0 → mergeron-2025.739290.0}/src/mergeron/core/ftc_merger_investigations_data.py +1 -1
  7. {mergeron-2025.739265.0 → mergeron-2025.739290.0}/src/mergeron/core/guidelines_boundaries.py +18 -15
  8. {mergeron-2025.739265.0 → mergeron-2025.739290.0}/src/mergeron/core/guidelines_boundary_functions.py +1 -1
  9. {mergeron-2025.739265.0 → mergeron-2025.739290.0}/src/mergeron/core/guidelines_boundary_functions_extra.py +4 -4
  10. {mergeron-2025.739265.0 → mergeron-2025.739290.0}/src/mergeron/core/pseudorandom_numbers.py +146 -108
  11. {mergeron-2025.739265.0 → mergeron-2025.739290.0}/src/mergeron/demo/visualize_empirical_margin_distribution.py +1 -1
  12. {mergeron-2025.739265.0 → mergeron-2025.739290.0}/src/mergeron/gen/__init__.py +226 -88
  13. {mergeron-2025.739265.0 → mergeron-2025.739290.0}/src/mergeron/gen/data_generation.py +144 -177
  14. {mergeron-2025.739265.0 → mergeron-2025.739290.0}/src/mergeron/gen/data_generation_functions.py +73 -122
  15. {mergeron-2025.739265.0 → mergeron-2025.739290.0}/src/mergeron/gen/enforcement_stats.py +30 -6
  16. {mergeron-2025.739265.0 → mergeron-2025.739290.0}/src/mergeron/gen/upp_tests.py +9 -10
  17. {mergeron-2025.739265.0 → mergeron-2025.739290.0}/README.rst +0 -0
  18. {mergeron-2025.739265.0 → mergeron-2025.739290.0}/src/mergeron/data/__init__.py +0 -0
  19. {mergeron-2025.739265.0 → mergeron-2025.739290.0}/src/mergeron/data/damodaran_margin_data.xls +0 -0
  20. {mergeron-2025.739265.0 → mergeron-2025.739290.0}/src/mergeron/data/damodaran_margin_data_dict.msgpack +0 -0
  21. {mergeron-2025.739265.0 → mergeron-2025.739290.0}/src/mergeron/data/ftc_invdata.msgpack +0 -0
  22. {mergeron-2025.739265.0 → mergeron-2025.739290.0}/src/mergeron/demo/__init__.py +0 -0
  23. {mergeron-2025.739265.0 → mergeron-2025.739290.0}/src/mergeron/py.typed +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: mergeron
3
- Version: 2025.739265.0
3
+ Version: 2025.739290.0
4
4
  Summary: Analyze merger enforcement policy using Python
5
5
  License: MIT
6
6
  Keywords: merger policy analysis,merger guidelines,merger screening,policy presumptions,concentration standards,upward pricing pressure,GUPPI
@@ -30,6 +30,7 @@ Requires-Dist: matplotlib (>=3.8)
30
30
  Requires-Dist: mpmath (>=1.3)
31
31
  Requires-Dist: msgpack (>=1.0)
32
32
  Requires-Dist: msgpack-numpy (>=0.4)
33
+ Requires-Dist: ruamel-yaml (>=0.18.10,<0.19.0)
33
34
  Requires-Dist: scipy (>=1.12)
34
35
  Requires-Dist: sympy (>=1.12)
35
36
  Requires-Dist: tables (>=3.10.1)
@@ -13,7 +13,7 @@ keywords = [
13
13
  "upward pricing pressure",
14
14
  "GUPPI",
15
15
  ]
16
- version = "2025.739265.0"
16
+ version = "2025.739290.0"
17
17
 
18
18
  # Classifiers list: https://pypi.org/classifiers/
19
19
  classifiers = [
@@ -57,10 +57,12 @@ certifi = ">=2023.11.17"
57
57
  types-beautifulsoup4 = ">=4.11.2"
58
58
  xlrd = "^2.0.1" # Needed to read margin data
59
59
  urllib3 = "^2.2.2"
60
+ ruamel-yaml = "^0.18.10"
60
61
 
61
62
 
62
63
  [tool.poetry.group.dev.dependencies]
63
64
  icecream = ">=2.1.0"
65
+ jinja2 = ">=3.1.5"
64
66
  mypy = ">=1.8"
65
67
  openpyxl = ">=3.1.2"
66
68
  pendulum = ">=3.0.0"
@@ -74,8 +76,7 @@ sphinx-autoapi = ">=3.0"
74
76
  sphinx-immaterial = ">=0.11"
75
77
  pipdeptree = ">=2.15.1"
76
78
  types-openpyxl = ">=3.0.0"
77
- pyright = "^1.1.380"
78
-
79
+ virtualenv = ">=20.28.0"
79
80
  [tool.ruff]
80
81
 
81
82
  # Exclude a variety of commonly ignored directories.
@@ -1,18 +1,25 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import enum
4
+ from multiprocessing import cpu_count
4
5
  from pathlib import Path
5
6
  from typing import Literal
6
7
 
7
8
  import numpy as np
8
9
  from numpy.typing import NDArray
10
+ from ruamel import yaml
9
11
 
10
12
  _PKG_NAME: str = Path(__file__).parent.stem
11
13
 
12
- VERSION = "2025.739265.0"
14
+ VERSION = "2025.739290.0"
13
15
 
14
16
  __version__ = VERSION
15
17
 
18
+ this_yaml = yaml.YAML(typ="safe", pure=True)
19
+ this_yaml.constructor.deep_construct = True
20
+ this_yaml.indent(mapping=2, sequence=4, offset=2)
21
+
22
+
16
23
  DATA_DIR: Path = Path.home() / _PKG_NAME
17
24
  """
18
25
  Defines a subdirectory named for this package in the user's home path.
@@ -22,6 +29,13 @@ If the subdirectory doesn't exist, it is created on package invocation.
22
29
  if not DATA_DIR.is_dir():
23
30
  DATA_DIR.mkdir(parents=False)
24
31
 
32
+ DEFAULT_REC_RATIO = 0.85
33
+
34
+ EMPTY_ARRAYDOUBLE = np.array([], float)
35
+ EMPTY_ARRAYINT = np.array([], int)
36
+
37
+ NTHREADS = 2 * cpu_count()
38
+
25
39
  np.set_printoptions(precision=24, floatmode="fixed")
26
40
 
27
41
  type HMGPubYear = Literal[1982, 1984, 1992, 2010, 2023]
@@ -33,7 +47,24 @@ type ArrayINT = NDArray[np.intp]
33
47
  type ArrayDouble = NDArray[np.float64]
34
48
  type ArrayBIGINT = NDArray[np.int64]
35
49
 
36
- DEFAULT_REC_RATIO = 0.85
50
+ ## Add yaml representer, constructor for NoneType
51
+ (_, _) = (
52
+ this_yaml.representer.add_representer(
53
+ type(None), lambda _r, _d: _r.represent_scalar("!None", "none")
54
+ ),
55
+ this_yaml.constructor.add_constructor("!None", lambda _c, _n, /: None),
56
+ )
57
+
58
+ ## Add yaml representer, constructor for ndarray
59
+ (_, _) = (
60
+ this_yaml.representer.add_representer(
61
+ np.ndarray,
62
+ lambda _r, _d: _r.represent_sequence("!ndarray", (_d.tolist(), _d.dtype.str)),
63
+ ),
64
+ this_yaml.constructor.add_constructor(
65
+ "!ndarray", lambda _c, _n, /: np.array(*_c.construct_sequence(_n))
66
+ ),
67
+ )
37
68
 
38
69
 
39
70
  @enum.unique
@@ -63,3 +94,21 @@ class UPPAggrSelector(enum.StrEnum):
63
94
  OSA = "own-share weighted average"
64
95
  OSD = "own-share weighted distance"
65
96
  OSG = "own-share weighted geometric mean"
97
+
98
+
99
+ for _typ in (RECForm, UPPAggrSelector):
100
+ # NOTE: If additional enums are defined in this module,
101
+ # add them to the list above
102
+
103
+ _, _ = (
104
+ this_yaml.representer.add_representer(
105
+ _typ,
106
+ lambda _r, _d: _r.represent_scalar(f"!{_d.__class__.__name__}", _d.name),
107
+ ),
108
+ this_yaml.constructor.add_constructor(
109
+ f"!{_typ.__name__}",
110
+ lambda _c, _n, /: getattr(
111
+ globals().get(_n.tag.lstrip("!")), _c.construct_scalar(_n)
112
+ ),
113
+ ),
114
+ )
@@ -4,5 +4,5 @@ from .. import VERSION # noqa: TID252
4
4
 
5
5
  __version__ = VERSION
6
6
 
7
- type MPFloat = mp.mpf # pyright: ignore
8
- type MPMatrix = mp.matrix # pyright: ignore
7
+ type MPFloat = mp.mpf # type: ignore
8
+ type MPMatrix = mp.matrix # type: ignore
@@ -132,7 +132,7 @@ def mgn_data_getter( # noqa: PLR0912
132
132
  _xl_row[1] = int(_xl_row[1])
133
133
  _mgn_dict[_xl_row[0]] = dict(zip(_mgn_row_keys[1:], _xl_row[1:], strict=True))
134
134
 
135
- _ = _data_archive_path.write_bytes(msgpack.packb(_mgn_dict)) # pyright: ignore
135
+ _ = _data_archive_path.write_bytes(msgpack.packb(_mgn_dict))
136
136
 
137
137
  return MappingProxyType(_mgn_dict)
138
138
 
@@ -221,7 +221,7 @@ def mgn_data_resampler(
221
221
  _x, _w, _ = mgn_data_builder(mgn_data_getter())
222
222
 
223
223
  _mgn_kde = stats.gaussian_kde(_x, weights=_w, bw_method="silverman")
224
- _mgn_kde.set_bandwidth(bw_method=_mgn_kde.factor / 3.0) # pyright: ignore
224
+ _mgn_kde.set_bandwidth(bw_method=_mgn_kde.factor / 3.0)
225
225
 
226
226
  if isinstance(_sample_size, int):
227
227
  return np.array(
@@ -198,7 +198,7 @@ def construct_data(
198
198
  )
199
199
  }
200
200
 
201
- _ = INVDATA_ARCHIVE_PATH.write_bytes(msgpack.packb(_invdata)) # pyright: ignore
201
+ _ = INVDATA_ARCHIVE_PATH.write_bytes(msgpack.packb(_invdata))
202
202
 
203
203
  return MappingProxyType(_invdata)
204
204
 
@@ -62,7 +62,7 @@ class GuidelinesThresholds:
62
62
  Year of publication of the Guidelines
63
63
  """
64
64
 
65
- safeharbor: HMGThresholds = field(kw_only=True, default=None)
65
+ safeharbor: HMGThresholds = field(kw_only=True, default=None, init=False)
66
66
  """
67
67
  Negative presumption quantified on various measures
68
68
 
@@ -70,7 +70,7 @@ class GuidelinesThresholds:
70
70
  diversion ratio limit, CMCR, and IPR
71
71
  """
72
72
 
73
- presumption: HMGThresholds = field(kw_only=True, default=None)
73
+ presumption: HMGThresholds = field(kw_only=True, default=None, init=False)
74
74
  """
75
75
  Presumption of harm defined in HMG
76
76
 
@@ -78,7 +78,7 @@ class GuidelinesThresholds:
78
78
  diversion ratio limit, CMCR, and IPR
79
79
  """
80
80
 
81
- imputed_presumption: HMGThresholds = field(kw_only=True, default=None)
81
+ imputed_presumption: HMGThresholds = field(kw_only=True, default=None, init=False)
82
82
  """
83
83
  Presumption of harm imputed from guidelines
84
84
 
@@ -153,25 +153,28 @@ class ConcentrationBoundary:
153
153
  """Concentration parameters, boundary coordinates, and area under concentration boundary."""
154
154
 
155
155
  measure_name: Literal[
156
- "ΔHHI", "Combined share", "Pre-merger HHI", "Post-merger HHI"
156
+ "ΔHHI",
157
+ "Combined share",
158
+ "Pre-merger HHI Contribution",
159
+ "Post-merger HHI Contribution",
157
160
  ] = field(kw_only=False, default="ΔHHI")
158
161
 
159
- @measure_name.validator # pyright: ignore
160
- def __mnv(
162
+ @measure_name.validator
163
+ def _mnv(
161
164
  _instance: ConcentrationBoundary, _attribute: Attribute[str], _value: str, /
162
165
  ) -> None:
163
166
  if _value not in (
164
167
  "ΔHHI",
165
168
  "Combined share",
166
- "Pre-merger HHI",
167
- "Post-merger HHI",
169
+ "Pre-merger HHI Contribution",
170
+ "Post-merger HHI Contribution",
168
171
  ):
169
172
  raise ValueError(f"Invalid name for a concentration measure, {_value!r}.")
170
173
 
171
174
  threshold: float = field(kw_only=False, default=0.01)
172
175
 
173
- @threshold.validator # pyright: ignore
174
- def __tv(
176
+ @threshold.validator
177
+ def _tv(
175
178
  _instance: ConcentrationBoundary, _attribute: Attribute[float], _value: float, /
176
179
  ) -> None:
177
180
  if not 0 <= _value <= 1:
@@ -193,9 +196,9 @@ class ConcentrationBoundary:
193
196
  _conc_fn = gbfn.hhi_delta_boundary
194
197
  case "Combined share":
195
198
  _conc_fn = gbfn.combined_share_boundary
196
- case "Pre-merger HHI":
199
+ case "Pre-merger HHI Contribution":
197
200
  _conc_fn = gbfn.hhi_pre_contrib_boundary
198
- case "Post-merger HHI":
201
+ case "Post-merger HHI Contribution":
199
202
  _conc_fn = gbfn.hhi_post_contrib_boundary
200
203
 
201
204
  _boundary = _conc_fn(self.threshold, dps=self.precision)
@@ -221,13 +224,13 @@ class DiversionRatioBoundary:
221
224
  diversion_ratio: float = field(kw_only=False, default=0.065)
222
225
 
223
226
  @diversion_ratio.validator
224
- def __dvv(
227
+ def _dvv(
225
228
  _instance: DiversionRatioBoundary,
226
229
  _attribute: Attribute[float],
227
230
  _value: float,
228
231
  /,
229
232
  ) -> None:
230
- if not (isinstance(_value, float) and 0 <= _value <= 1):
233
+ if not (isinstance(_value, decimal.Decimal | float) and 0 <= _value <= 1):
231
234
  raise ValueError(
232
235
  "Margin-adjusted benchmark share ratio must lie between 0 and 1."
233
236
  )
@@ -260,7 +263,7 @@ class DiversionRatioBoundary:
260
263
  """
261
264
 
262
265
  @recapture_form.validator
263
- def __rsv(
266
+ def _rsv(
264
267
  _instance: DiversionRatioBoundary,
265
268
  _attribute: Attribute[RECForm],
266
269
  _value: RECForm,
@@ -909,7 +909,7 @@ def lerp[LerpT: (float, MPFloat, ArrayDouble, ArrayBIGINT)](
909
909
  elif _r == 1:
910
910
  return _x2
911
911
  else:
912
- return _r * _x2 + (1 - _r) * _x1 # pyright: ignore
912
+ return _r * _x2 + (1 - _r) * _x1
913
913
 
914
914
 
915
915
  def round_cust(
@@ -91,7 +91,7 @@ def hhi_delta_boundary_qdtr(_dh_val: float = 0.01, /) -> GuidelinesBoundaryCalla
91
91
 
92
92
  _hhi_bdry_area = 2 * (
93
93
  _s_nought
94
- + mp.quad(lambdify(_s_1, _hhi_bdry, "mpmath"), (_s_nought, 1 - _s_nought)) # pyright: ignore
94
+ + mp.quad(lambdify(_s_1, _hhi_bdry, "mpmath"), (_s_nought, 1 - _s_nought))
95
95
  )
96
96
 
97
97
  return GuidelinesBoundaryCallable(
@@ -159,7 +159,7 @@ def shrratio_boundary_qdtr_wtd_avg(
159
159
  2
160
160
  * (
161
161
  _s_naught
162
- + mp.quad(lambdify(_s_1, _bdry_func, "mpmath"), (_s_naught, _s_mid)) # pyright: ignore
162
+ + mp.quad(lambdify(_s_1, _bdry_func, "mpmath"), (_s_naught, _s_mid))
163
163
  )
164
164
  - (_s_mid**2 + _s_naught**2)
165
165
  )
@@ -189,7 +189,7 @@ def shrratio_boundary_qdtr_wtd_avg(
189
189
  ),
190
190
  (0, _s_mid),
191
191
  )
192
- ).real # pyright: ignore
192
+ ).real
193
193
  - _s_mid**2
194
194
  )
195
195
 
@@ -209,7 +209,7 @@ def shrratio_boundary_qdtr_wtd_avg(
209
209
 
210
210
  _bdry_func = solve(_bdry_eqn, _s_2)[0]
211
211
  _bdry_area = float(
212
- 2 * (mp.quad(lambdify(_s_1, _bdry_func, "mpmath"), (0, _s_mid))) # pyright: ignore
212
+ 2 * (mp.quad(lambdify(_s_1, _bdry_func, "mpmath"), (0, _s_mid)))
213
213
  - _s_mid**2
214
214
  )
215
215
 
@@ -6,20 +6,38 @@ https://github.com/numpy/numpy/issues/16313.
6
6
 
7
7
  """
8
8
 
9
+ from __future__ import annotations
10
+
9
11
  import concurrent.futures
10
12
  from collections.abc import Sequence
11
- from multiprocessing import cpu_count
12
13
  from typing import Literal
13
14
 
14
15
  import numpy as np
16
+ from attrs import Attribute, Converter, define, field
15
17
  from numpy.random import PCG64DXSM, Generator, SeedSequence
16
18
 
17
- from .. import VERSION, ArrayDouble # noqa: TID252
19
+ from .. import NTHREADS, VERSION, ArrayDouble, ArrayFloat, this_yaml # noqa: TID252
18
20
 
19
21
  __version__ = VERSION
20
22
 
21
- NTHREADS = 2 * cpu_count()
22
- DEFAULT_DIST_PARMS = np.array([0.0, 1.0], np.float64)
23
+ DEFAULT_DIST_PARMS: ArrayFloat = np.array([0.0, 1.0], float)
24
+ DEFAULT_BETA_DIST_PARMS: ArrayFloat = np.array([1.0, 1.0], float)
25
+
26
+
27
+ # Add yaml representer, constructor for SeedSequence
28
+ this_yaml.representer.add_representer(
29
+ SeedSequence,
30
+ lambda _r, _d: _r.represent_mapping(
31
+ "!SeedSequence",
32
+ {
33
+ _a: getattr(_d, _a)
34
+ for _a in ("entropy", "spawn_key", "pool_size", "n_children_spawned")
35
+ },
36
+ ),
37
+ )
38
+ this_yaml.constructor.add_constructor(
39
+ "!SeedSequence", lambda _c, _n, /: SeedSequence(**_c.construct_mapping(_n))
40
+ )
23
41
 
24
42
 
25
43
  def prng(_s: SeedSequence | None = None, /) -> np.random.Generator:
@@ -106,6 +124,21 @@ def gen_seed_seq_list_default(
106
124
  return [SeedSequence(_s, pool_size=8) for _s in generated_entropy[:_sseq_list_len]]
107
125
 
108
126
 
127
+ def _dist_parms_conv(_v: ArrayFloat, _i: MultithreadedRNG) -> ArrayFloat:
128
+ if not len(_v):
129
+ return {
130
+ "Beta": DEFAULT_BETA_DIST_PARMS,
131
+ "Dirichlet": np.ones(_i.values.shape[-1], float),
132
+ }.get(_i.dist_type, DEFAULT_DIST_PARMS)
133
+ elif isinstance(_v, Sequence | np.ndarray):
134
+ return np.asarray(_v, float)
135
+ else:
136
+ raise ValueError(
137
+ "Input, {_v!r} has invalid type. Must be None, Sequence of floats, or Numpy ndarray."
138
+ )
139
+
140
+
141
+ @define
109
142
  class MultithreadedRNG:
110
143
  """Fill given array on demand with pseudo-random numbers as specified.
111
144
 
@@ -114,98 +147,105 @@ class MultithreadedRNG:
114
147
  If a seed sequence is provided, it is used in a thread-safe way
115
148
  to generate repeatable i.i.d. draws. All arguments are validated
116
149
  before commencing multithreaded random number generation.
150
+ """
117
151
 
118
- Parameters
119
- ----------
120
- __out_array
121
- The output array to which generated data are written.
122
- Its dimensions define the size of the sample.
123
- dist_type
124
- Distribution for the generated random numbers
125
- dist_parms
126
- Parameters, if any, for tailoring random number generation
127
- seed_sequence
128
- SeedSequence object for generating repeatable draws.
129
- nthreads
130
- Number of threads to spawn for random number generation.
152
+ values: ArrayDouble = field(kw_only=False)
153
+ """Output array to which generated data are over-written
131
154
 
155
+ Array-length defines the number of i.i.d. (vector) draws.
132
156
  """
133
157
 
134
- def __init__(
135
- self,
136
- __out_array: ArrayDouble,
158
+ @values.validator
159
+ def _vsv(
160
+ _instance: MultithreadedRNG,
161
+ _attribute: Attribute[ArrayDouble],
162
+ _value: ArrayDouble,
137
163
  /,
138
- *,
139
- dist_type: Literal[
140
- "Beta", "Dirichlet", "Gaussian", "Normal", "Random", "Uniform"
141
- ] = "Uniform",
142
- dist_parms: ArrayDouble | None = DEFAULT_DIST_PARMS,
143
- seed_sequence: SeedSequence | None = None,
144
- nthreads: int = NTHREADS,
145
- ):
146
- self.thread_count = nthreads
147
-
148
- __seed_sequence = seed_sequence or SeedSequence(pool_size=8)
149
- self._random_generators = [
150
- prng(_t) for _t in __seed_sequence.spawn(self.thread_count)
151
- ]
152
-
153
- self.sample_sz = len(__out_array)
154
-
155
- if dist_type not in (_rdts := ("Beta", "Dirichlet", "Normal", "Uniform")):
156
- raise ValueError("Specified distribution must be one of {_rdts}")
157
-
158
- if not (dist_parms is None or isinstance(dist_parms, Sequence | np.ndarray)):
159
- raise ValueError(
160
- "When specified, distribution parameters must be a list, tuple or Numpy array"
161
- )
162
- if isinstance(dist_parms, Sequence):
163
- dist_parms = np.array(dist_parms)
164
- elif not dist_parms.any():
165
- dist_parms = None
166
-
167
- self.dist_type = dist_type
168
-
169
- if dist_parms is None or np.array_equal(dist_parms, DEFAULT_DIST_PARMS):
170
- match dist_type:
171
- case "Uniform":
172
- self.dist_type = "Random"
173
- case "Normal":
174
- self.dist_type = "Gaussian"
175
- case "Beta" | "Dirichlet":
176
- raise ValueError(
177
- f"parameter specification, {f'"{dist_parms}"'} "
178
- f"is invalid for specified distribution, f{'"{dist_type}"'}."
179
- )
180
- case _:
181
- raise ValueError(
182
- f"Invalid distributions specified, {f'"{dist_type}"'}."
183
- )
184
-
185
- elif dist_type == "Dirichlet":
186
- if len(dist_parms) != __out_array.shape[1]:
187
- raise ValueError(
188
- f"Insufficient shape parameters for requested Dirichlet sample "
189
- f"of size, {__out_array.shape}"
190
- )
164
+ ) -> None:
165
+ if not len(_value):
166
+ raise ValueError("Output array must at least be one dimension")
167
+
168
+ dist_type: Literal[
169
+ "Beta", "Dirichlet", "Gaussian", "Normal", "Random", "Uniform"
170
+ ] = field(default="Uniform")
171
+ """Distribution for the generated random numbers.
172
+
173
+ Default is "Uniform".
174
+ """
175
+
176
+ @dist_type.validator
177
+ def _dtv(
178
+ _instance: MultithreadedRNG, _attribute: Attribute[str], _value: str, /
179
+ ) -> None:
180
+ if _value not in (
181
+ _rdts := ("Beta", "Dirichlet", "Gaussian", "Normal", "Random", "Uniform")
182
+ ):
183
+ raise ValueError(f"Specified distribution must be one of {_rdts}")
184
+
185
+ dist_parms: ArrayFloat = field(
186
+ converter=Converter(_dist_parms_conv, takes_self=True) # type: ignore
187
+ )
188
+ """Parameters, if any, for tailoring random number generation
189
+ """
191
190
 
192
- elif (_lrdp := len(dist_parms)) != 2:
193
- raise ValueError(f"Expected 2 parameters, got, {_lrdp}")
191
+ @dist_parms.default
192
+ def _dpd(_instance: MultithreadedRNG) -> ArrayFloat:
193
+ return {
194
+ "Beta": DEFAULT_BETA_DIST_PARMS,
195
+ "Dirichlet": np.ones(_instance.values.shape[-1], float),
196
+ }.get(_instance.dist_type, DEFAULT_DIST_PARMS)
197
+
198
+ @dist_parms.validator
199
+ def _dpv(
200
+ _instance: MultithreadedRNG,
201
+ _attribute: Attribute[ArrayFloat],
202
+ _value: ArrayFloat,
203
+ /,
204
+ ) -> None:
205
+ if (
206
+ _instance.dist_type != "Dirichlet"
207
+ and (_lrdp := len(_value)) != (_trdp := 2)
208
+ ) or (
209
+ _instance.dist_type == "Dirichlet"
210
+ and (_lrdp := len(_value)) != (_trdp := _instance.values.shape[1])
211
+ ):
212
+ raise ValueError(f"Expected {_trdp} parameters, got, {_lrdp}")
194
213
 
195
- self.dist_parms = dist_parms
214
+ elif _instance.dist_type in ("Beta", "Dirichlet") and (_value <= 0.0).any():
215
+ raise ValueError("Shape and location parameters must be strictly positive")
196
216
 
197
- self.values = __out_array
198
- self.executor = concurrent.futures.ThreadPoolExecutor(self.thread_count)
217
+ seed_sequence: SeedSequence | None = field(default=None)
218
+ """Seed sequence for generating random numbers."""
199
219
 
200
- self.step_size = (len(self.values) / self.thread_count).__ceil__()
220
+ nthreads: int = field(default=NTHREADS)
221
+ """Number of threads to spawn for random number generation."""
201
222
 
202
223
  def fill(self) -> None:
203
- """Fill the provided output array with random numbers as specified."""
224
+ """Fill the provided output array with random number draws as specified."""
225
+
226
+ if not len(self.dist_parms) or np.array_equal(
227
+ self.dist_parms, DEFAULT_DIST_PARMS
228
+ ):
229
+ if self.dist_type == "Uniform":
230
+ _dist_type = "Random"
231
+ elif self.dist_type == "Normal":
232
+ _dist_type = "Gaussian"
233
+ else:
234
+ _dist_type = self.dist_type
235
+
236
+ _step_size = (len(self.values) / self.nthreads).__ceil__()
237
+ # int; function gives float unsuitable for slicing
238
+
239
+ _seed_sequence = self.seed_sequence or SeedSequence(pool_size=8)
240
+
241
+ _random_generators = tuple(
242
+ prng(_t) for _t in _seed_sequence.spawn(self.nthreads)
243
+ )
204
244
 
205
245
  def _fill(
206
246
  _rng: np.random.Generator,
207
247
  _dist_type: str,
208
- _dist_parms: ArrayDouble,
248
+ _dist_parms: ArrayFloat,
209
249
  _out: ArrayDouble,
210
250
  _first: int,
211
251
  _last: int,
@@ -213,37 +253,35 @@ class MultithreadedRNG:
213
253
  ) -> None:
214
254
  _sz: tuple[int, ...] = _out[_first:_last].shape
215
255
  match _dist_type:
216
- case "Random":
217
- _rng.random(out=_out[_first:_last])
218
- case "Uniform":
219
- _uni_l, _uni_h = _dist_parms
220
- _out[_first:_last] = _rng.uniform(_uni_l, _uni_h, size=_sz)
221
- case "Dirichlet":
222
- _out[_first:_last] = _rng.dirichlet(_dist_parms, size=_sz[:-1])
223
256
  case "Beta":
224
257
  _shape_a, _shape_b = _dist_parms
225
258
  _out[_first:_last] = _rng.beta(_shape_a, _shape_b, size=_sz)
259
+ case "Dirichlet":
260
+ _out[_first:_last] = _rng.dirichlet(_dist_parms, size=_sz[:-1])
261
+ case "Gaussian":
262
+ _rng.standard_normal(out=_out[_first:_last])
226
263
  case "Normal":
227
264
  _mu, _sigma = _dist_parms
228
265
  _out[_first:_last] = _rng.normal(_mu, _sigma, size=_sz)
266
+ case "Random":
267
+ _rng.random(out=_out[_first:_last])
268
+ case "Uniform":
269
+ _uni_l, _uni_h = _dist_parms
270
+ _out[_first:_last] = _rng.uniform(_uni_l, _uni_h, size=_sz)
229
271
  case _:
230
- _rng.standard_normal(out=_out[_first:_last])
231
-
232
- futures = {}
233
- for i in range(self.thread_count):
234
- _range_first = i * self.step_size
235
- _range_last = min(len(self.values), (i + 1) * self.step_size)
236
- args = (
237
- _fill,
238
- self._random_generators[i],
239
- self.dist_type,
240
- self.dist_parms,
241
- self.values,
242
- _range_first,
243
- _range_last,
244
- )
245
- futures[self.executor.submit(*args)] = i # type: ignore
246
- concurrent.futures.wait(futures)
247
-
248
- def __del__(self) -> None:
249
- self.executor.shutdown(False)
272
+ "Unreachable. The validator would have rejected this as invalid."
273
+
274
+ with concurrent.futures.ThreadPoolExecutor(self.nthreads) as _executor:
275
+ for i in range(self.nthreads):
276
+ _range_first = i * _step_size
277
+ _range_last = min(len(self.values), (i + 1) * _step_size)
278
+
279
+ _executor.submit(
280
+ _fill,
281
+ _random_generators[i],
282
+ _dist_type,
283
+ self.dist_parms,
284
+ self.values,
285
+ _range_first,
286
+ _range_last,
287
+ )
@@ -45,7 +45,7 @@ with warnings.catch_warnings():
45
45
  ])
46
46
 
47
47
  mgn_kde = stats.gaussian_kde(mgn_data_obs, weights=mgn_data_wts, bw_method="silverman")
48
- mgn_kde.set_bandwidth(bw_method=mgn_kde.factor / 3.0) # pyright: ignore
48
+ mgn_kde.set_bandwidth(bw_method=mgn_kde.factor / 3.0)
49
49
 
50
50
  mgn_ax.plot(
51
51
  (_xv := np.linspace(0, BIN_COUNT, 10**5) / BIN_COUNT),