mergeron 2025.739265.2__py3-none-any.whl → 2025.739290.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mergeron might be problematic. Click here for more details.
- mergeron/__init__.py +67 -2
- mergeron/core/guidelines_boundaries.py +85 -21
- mergeron/core/pseudorandom_numbers.py +61 -51
- mergeron/gen/__init__.py +222 -84
- mergeron/gen/data_generation.py +143 -182
- mergeron/gen/data_generation_functions.py +68 -118
- mergeron/gen/enforcement_stats.py +30 -6
- mergeron/gen/upp_tests.py +6 -7
- {mergeron-2025.739265.2.dist-info → mergeron-2025.739290.1.dist-info}/METADATA +2 -1
- {mergeron-2025.739265.2.dist-info → mergeron-2025.739290.1.dist-info}/RECORD +11 -11
- {mergeron-2025.739265.2.dist-info → mergeron-2025.739290.1.dist-info}/WHEEL +0 -0
mergeron/__init__.py
CHANGED
|
@@ -1,18 +1,26 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import enum
|
|
4
|
+
from multiprocessing import cpu_count
|
|
4
5
|
from pathlib import Path
|
|
5
6
|
from typing import Literal
|
|
6
7
|
|
|
7
8
|
import numpy as np
|
|
9
|
+
from numpy.random import SeedSequence
|
|
8
10
|
from numpy.typing import NDArray
|
|
11
|
+
from ruamel import yaml
|
|
9
12
|
|
|
10
13
|
_PKG_NAME: str = Path(__file__).parent.stem
|
|
11
14
|
|
|
12
|
-
VERSION = "2025.
|
|
15
|
+
VERSION = "2025.739290.1"
|
|
13
16
|
|
|
14
17
|
__version__ = VERSION
|
|
15
18
|
|
|
19
|
+
this_yaml = yaml.YAML(typ="safe", pure=True)
|
|
20
|
+
this_yaml.constructor.deep_construct = True
|
|
21
|
+
this_yaml.indent(mapping=2, sequence=4, offset=2)
|
|
22
|
+
|
|
23
|
+
|
|
16
24
|
DATA_DIR: Path = Path.home() / _PKG_NAME
|
|
17
25
|
"""
|
|
18
26
|
Defines a subdirectory named for this package in the user's home path.
|
|
@@ -22,6 +30,13 @@ If the subdirectory doesn't exist, it is created on package invocation.
|
|
|
22
30
|
if not DATA_DIR.is_dir():
|
|
23
31
|
DATA_DIR.mkdir(parents=False)
|
|
24
32
|
|
|
33
|
+
DEFAULT_REC_RATIO = 0.85
|
|
34
|
+
|
|
35
|
+
EMPTY_ARRAYDOUBLE = np.array([], float)
|
|
36
|
+
EMPTY_ARRAYINT = np.array([], int)
|
|
37
|
+
|
|
38
|
+
NTHREADS = 2 * cpu_count()
|
|
39
|
+
|
|
25
40
|
np.set_printoptions(precision=24, floatmode="fixed")
|
|
26
41
|
|
|
27
42
|
type HMGPubYear = Literal[1982, 1984, 1992, 2010, 2023]
|
|
@@ -33,7 +48,39 @@ type ArrayINT = NDArray[np.intp]
|
|
|
33
48
|
type ArrayDouble = NDArray[np.float64]
|
|
34
49
|
type ArrayBIGINT = NDArray[np.int64]
|
|
35
50
|
|
|
36
|
-
|
|
51
|
+
## Add yaml representer, constructor for NoneType
|
|
52
|
+
(_, _) = (
|
|
53
|
+
this_yaml.representer.add_representer(
|
|
54
|
+
type(None), lambda _r, _d: _r.represent_scalar("!None", "none")
|
|
55
|
+
),
|
|
56
|
+
this_yaml.constructor.add_constructor("!None", lambda _c, _n, /: None),
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
## Add yaml representer, constructor for ndarray
|
|
60
|
+
(_, _) = (
|
|
61
|
+
this_yaml.representer.add_representer(
|
|
62
|
+
np.ndarray,
|
|
63
|
+
lambda _r, _d: _r.represent_sequence("!ndarray", (_d.tolist(), _d.dtype.str)),
|
|
64
|
+
),
|
|
65
|
+
this_yaml.constructor.add_constructor(
|
|
66
|
+
"!ndarray", lambda _c, _n, /: np.array(*_c.construct_sequence(_n))
|
|
67
|
+
),
|
|
68
|
+
)
|
|
69
|
+
|
|
70
|
+
# Add yaml representer, constructor for SeedSequence
|
|
71
|
+
this_yaml.representer.add_representer(
|
|
72
|
+
SeedSequence,
|
|
73
|
+
lambda _r, _d: _r.represent_mapping(
|
|
74
|
+
"!SeedSequence",
|
|
75
|
+
{
|
|
76
|
+
_a: getattr(_d, _a)
|
|
77
|
+
for _a in ("entropy", "spawn_key", "pool_size", "n_children_spawned")
|
|
78
|
+
},
|
|
79
|
+
),
|
|
80
|
+
)
|
|
81
|
+
this_yaml.constructor.add_constructor(
|
|
82
|
+
"!SeedSequence", lambda _c, _n, /: SeedSequence(**_c.construct_mapping(_n))
|
|
83
|
+
)
|
|
37
84
|
|
|
38
85
|
|
|
39
86
|
@enum.unique
|
|
@@ -63,3 +110,21 @@ class UPPAggrSelector(enum.StrEnum):
|
|
|
63
110
|
OSA = "own-share weighted average"
|
|
64
111
|
OSD = "own-share weighted distance"
|
|
65
112
|
OSG = "own-share weighted geometric mean"
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
for _typ in (RECForm, UPPAggrSelector):
|
|
116
|
+
# NOTE: If additional enums are defined in this module,
|
|
117
|
+
# add themn to the list above
|
|
118
|
+
|
|
119
|
+
_, _ = (
|
|
120
|
+
this_yaml.representer.add_representer(
|
|
121
|
+
_typ,
|
|
122
|
+
lambda _r, _d: _r.represent_scalar(f"!{_d.__class__.__name__}", _d.name),
|
|
123
|
+
),
|
|
124
|
+
this_yaml.constructor.add_constructor(
|
|
125
|
+
f"!{_typ.__name__}",
|
|
126
|
+
lambda _c, _n, /: getattr(
|
|
127
|
+
globals().get(_n.tag.lstrip("!")), _c.construct_scalar(_n)
|
|
128
|
+
),
|
|
129
|
+
),
|
|
130
|
+
)
|
|
@@ -13,6 +13,7 @@ from typing import Literal
|
|
|
13
13
|
import numpy as np
|
|
14
14
|
from attrs import Attribute, field, frozen, validators
|
|
15
15
|
from mpmath import mp, mpf # type: ignore
|
|
16
|
+
from ruamel import yaml
|
|
16
17
|
|
|
17
18
|
from .. import ( # noqa: TID252
|
|
18
19
|
DEFAULT_REC_RATIO,
|
|
@@ -21,6 +22,7 @@ from .. import ( # noqa: TID252
|
|
|
21
22
|
HMGPubYear,
|
|
22
23
|
RECForm,
|
|
23
24
|
UPPAggrSelector,
|
|
25
|
+
this_yaml,
|
|
24
26
|
)
|
|
25
27
|
from . import guidelines_boundary_functions as gbfn
|
|
26
28
|
|
|
@@ -42,6 +44,7 @@ class HMGThresholds:
|
|
|
42
44
|
ipr: float
|
|
43
45
|
|
|
44
46
|
|
|
47
|
+
@this_yaml.register_class
|
|
45
48
|
@frozen
|
|
46
49
|
class GuidelinesThresholds:
|
|
47
50
|
"""
|
|
@@ -62,7 +65,7 @@ class GuidelinesThresholds:
|
|
|
62
65
|
Year of publication of the Guidelines
|
|
63
66
|
"""
|
|
64
67
|
|
|
65
|
-
safeharbor: HMGThresholds = field(kw_only=True, default=None)
|
|
68
|
+
safeharbor: HMGThresholds = field(kw_only=True, default=None, init=False)
|
|
66
69
|
"""
|
|
67
70
|
Negative presumption quantified on various measures
|
|
68
71
|
|
|
@@ -70,7 +73,7 @@ class GuidelinesThresholds:
|
|
|
70
73
|
diversion ratio limit, CMCR, and IPR
|
|
71
74
|
"""
|
|
72
75
|
|
|
73
|
-
presumption: HMGThresholds = field(kw_only=True, default=None)
|
|
76
|
+
presumption: HMGThresholds = field(kw_only=True, default=None, init=False)
|
|
74
77
|
"""
|
|
75
78
|
Presumption of harm defined in HMG
|
|
76
79
|
|
|
@@ -78,7 +81,7 @@ class GuidelinesThresholds:
|
|
|
78
81
|
diversion ratio limit, CMCR, and IPR
|
|
79
82
|
"""
|
|
80
83
|
|
|
81
|
-
imputed_presumption: HMGThresholds = field(kw_only=True, default=None)
|
|
84
|
+
imputed_presumption: HMGThresholds = field(kw_only=True, default=None, init=False)
|
|
82
85
|
"""
|
|
83
86
|
Presumption of harm imputed from guidelines
|
|
84
87
|
|
|
@@ -147,31 +150,51 @@ class GuidelinesThresholds:
|
|
|
147
150
|
),
|
|
148
151
|
)
|
|
149
152
|
|
|
153
|
+
@classmethod
|
|
154
|
+
def to_yaml(
|
|
155
|
+
cls, _r: yaml.representer.SafeRepresenter, _d: GuidelinesThresholds
|
|
156
|
+
) -> yaml.MappingNode:
|
|
157
|
+
_ret: yaml.MappingNode = _r.represent_mapping(
|
|
158
|
+
f"!{cls.__name__}",
|
|
159
|
+
{_a.name: getattr(_d, _a.name) for _a in _d.__attrs_attrs__},
|
|
160
|
+
)
|
|
161
|
+
return _ret
|
|
162
|
+
|
|
163
|
+
@classmethod
|
|
164
|
+
def from_yaml(
|
|
165
|
+
cls, _c: yaml.constructor.SafeConstructor, _n: yaml.MappingNode
|
|
166
|
+
) -> GuidelinesThresholds:
|
|
167
|
+
return cls(**_c.construct_mapping(_n))
|
|
150
168
|
|
|
169
|
+
|
|
170
|
+
@this_yaml.register_class
|
|
151
171
|
@frozen
|
|
152
172
|
class ConcentrationBoundary:
|
|
153
173
|
"""Concentration parameters, boundary coordinates, and area under concentration boundary."""
|
|
154
174
|
|
|
155
175
|
measure_name: Literal[
|
|
156
|
-
"ΔHHI",
|
|
176
|
+
"ΔHHI",
|
|
177
|
+
"Combined share",
|
|
178
|
+
"Pre-merger HHI Contribution",
|
|
179
|
+
"Post-merger HHI Contribution",
|
|
157
180
|
] = field(kw_only=False, default="ΔHHI")
|
|
158
181
|
|
|
159
182
|
@measure_name.validator
|
|
160
|
-
def
|
|
183
|
+
def _mnv(
|
|
161
184
|
_instance: ConcentrationBoundary, _attribute: Attribute[str], _value: str, /
|
|
162
185
|
) -> None:
|
|
163
186
|
if _value not in (
|
|
164
187
|
"ΔHHI",
|
|
165
188
|
"Combined share",
|
|
166
|
-
"Pre-merger HHI",
|
|
167
|
-
"Post-merger HHI",
|
|
189
|
+
"Pre-merger HHI Contribution",
|
|
190
|
+
"Post-merger HHI Contribution",
|
|
168
191
|
):
|
|
169
192
|
raise ValueError(f"Invalid name for a concentration measure, {_value!r}.")
|
|
170
193
|
|
|
171
194
|
threshold: float = field(kw_only=False, default=0.01)
|
|
172
195
|
|
|
173
196
|
@threshold.validator
|
|
174
|
-
def
|
|
197
|
+
def _tv(
|
|
175
198
|
_instance: ConcentrationBoundary, _attribute: Attribute[float], _value: float, /
|
|
176
199
|
) -> None:
|
|
177
200
|
if not 0 <= _value <= 1:
|
|
@@ -181,28 +204,49 @@ class ConcentrationBoundary:
|
|
|
181
204
|
kw_only=False, default=5, validator=validators.instance_of(int)
|
|
182
205
|
)
|
|
183
206
|
|
|
184
|
-
coordinates: ArrayDouble = field(init=False, kw_only=True)
|
|
185
|
-
"""Market-share pairs as Cartesian coordinates of points on the concentration boundary."""
|
|
186
|
-
|
|
187
207
|
area: float = field(init=False, kw_only=True)
|
|
188
208
|
"""Area under the concentration boundary."""
|
|
189
209
|
|
|
210
|
+
coordinates: ArrayDouble = field(init=False, kw_only=True)
|
|
211
|
+
"""Market-share pairs as Cartesian coordinates of points on the concentration boundary."""
|
|
212
|
+
|
|
190
213
|
def __attrs_post_init__(self, /) -> None:
|
|
191
214
|
match self.measure_name:
|
|
192
215
|
case "ΔHHI":
|
|
193
216
|
_conc_fn = gbfn.hhi_delta_boundary
|
|
194
217
|
case "Combined share":
|
|
195
218
|
_conc_fn = gbfn.combined_share_boundary
|
|
196
|
-
case "Pre-merger HHI":
|
|
219
|
+
case "Pre-merger HHI Contribution":
|
|
197
220
|
_conc_fn = gbfn.hhi_pre_contrib_boundary
|
|
198
|
-
case "Post-merger HHI":
|
|
221
|
+
case "Post-merger HHI Contribution":
|
|
199
222
|
_conc_fn = gbfn.hhi_post_contrib_boundary
|
|
200
223
|
|
|
201
224
|
_boundary = _conc_fn(self.threshold, dps=self.precision)
|
|
202
|
-
object.__setattr__(self, "coordinates", _boundary.coordinates)
|
|
203
225
|
object.__setattr__(self, "area", _boundary.area)
|
|
226
|
+
object.__setattr__(self, "coordinates", _boundary.coordinates)
|
|
227
|
+
|
|
228
|
+
@classmethod
|
|
229
|
+
def to_yaml(
|
|
230
|
+
cls, _r: yaml.representer.SafeRepresenter, _d: ConcentrationBoundary
|
|
231
|
+
) -> yaml.MappingNode:
|
|
232
|
+
_ret: yaml.MappingNode = _r.represent_mapping(
|
|
233
|
+
f"!{cls.__name__}",
|
|
234
|
+
{
|
|
235
|
+
_a.name: getattr(_d, _a.name)
|
|
236
|
+
for _a in _d.__attrs_attrs__
|
|
237
|
+
if _a.name not in ("area", "coordinates")
|
|
238
|
+
},
|
|
239
|
+
)
|
|
240
|
+
return _ret
|
|
204
241
|
|
|
242
|
+
@classmethod
|
|
243
|
+
def from_yaml(
|
|
244
|
+
cls, _c: yaml.constructor.SafeConstructor, _n: yaml.MappingNode
|
|
245
|
+
) -> ConcentrationBoundary:
|
|
246
|
+
return cls(**_c.construct_mapping(_n))
|
|
205
247
|
|
|
248
|
+
|
|
249
|
+
@this_yaml.register_class
|
|
206
250
|
@frozen
|
|
207
251
|
class DiversionRatioBoundary:
|
|
208
252
|
"""
|
|
@@ -221,13 +265,13 @@ class DiversionRatioBoundary:
|
|
|
221
265
|
diversion_ratio: float = field(kw_only=False, default=0.065)
|
|
222
266
|
|
|
223
267
|
@diversion_ratio.validator
|
|
224
|
-
def
|
|
268
|
+
def _dvv(
|
|
225
269
|
_instance: DiversionRatioBoundary,
|
|
226
270
|
_attribute: Attribute[float],
|
|
227
271
|
_value: float,
|
|
228
272
|
/,
|
|
229
273
|
) -> None:
|
|
230
|
-
if not (isinstance(_value, float) and 0 <= _value <= 1):
|
|
274
|
+
if not (isinstance(_value, decimal.Decimal | float) and 0 <= _value <= 1):
|
|
231
275
|
raise ValueError(
|
|
232
276
|
"Margin-adjusted benchmark share ratio must lie between 0 and 1."
|
|
233
277
|
)
|
|
@@ -260,7 +304,7 @@ class DiversionRatioBoundary:
|
|
|
260
304
|
"""
|
|
261
305
|
|
|
262
306
|
@recapture_form.validator
|
|
263
|
-
def
|
|
307
|
+
def _rsv(
|
|
264
308
|
_instance: DiversionRatioBoundary,
|
|
265
309
|
_attribute: Attribute[RECForm],
|
|
266
310
|
_value: RECForm,
|
|
@@ -307,12 +351,12 @@ class DiversionRatioBoundary:
|
|
|
307
351
|
|
|
308
352
|
"""
|
|
309
353
|
|
|
310
|
-
coordinates: ArrayDouble = field(init=False, kw_only=True)
|
|
311
|
-
"""Market-share pairs as Cartesian coordinates of points on the diversion ratio boundary."""
|
|
312
|
-
|
|
313
354
|
area: float = field(init=False, kw_only=True)
|
|
314
355
|
"""Area under the diversion ratio boundary."""
|
|
315
356
|
|
|
357
|
+
coordinates: ArrayDouble = field(init=False, kw_only=True)
|
|
358
|
+
"""Market-share pairs as Cartesian coordinates of points on the diversion ratio boundary."""
|
|
359
|
+
|
|
316
360
|
def __attrs_post_init__(self, /) -> None:
|
|
317
361
|
_share_ratio = critical_share_ratio(
|
|
318
362
|
self.diversion_ratio, r_bar=self.recapture_ratio
|
|
@@ -356,8 +400,28 @@ class DiversionRatioBoundary:
|
|
|
356
400
|
_upp_agg_kwargs |= {"agg_method": _aggregator, "weighting": _wgt_type}
|
|
357
401
|
|
|
358
402
|
_boundary = _upp_agg_fn(_share_ratio, self.recapture_ratio, **_upp_agg_kwargs)
|
|
359
|
-
object.__setattr__(self, "coordinates", _boundary.coordinates)
|
|
360
403
|
object.__setattr__(self, "area", _boundary.area)
|
|
404
|
+
object.__setattr__(self, "coordinates", _boundary.coordinates)
|
|
405
|
+
|
|
406
|
+
@classmethod
|
|
407
|
+
def to_yaml(
|
|
408
|
+
cls, _r: yaml.representer.SafeRepresenter, _d: DiversionRatioBoundary
|
|
409
|
+
) -> yaml.MappingNode:
|
|
410
|
+
_ret: yaml.MappingNode = _r.represent_mapping(
|
|
411
|
+
f"!{cls.__name__}",
|
|
412
|
+
{
|
|
413
|
+
_a.name: getattr(_d, _a.name)
|
|
414
|
+
for _a in _d.__attrs_attrs__
|
|
415
|
+
if _a.name not in ("area", "coordinates")
|
|
416
|
+
},
|
|
417
|
+
)
|
|
418
|
+
return _ret
|
|
419
|
+
|
|
420
|
+
@classmethod
|
|
421
|
+
def from_yaml(
|
|
422
|
+
cls, _c: yaml.constructor.SafeConstructor, _n: yaml.MappingNode
|
|
423
|
+
) -> DiversionRatioBoundary:
|
|
424
|
+
return cls(**_c.construct_mapping(_n))
|
|
361
425
|
|
|
362
426
|
|
|
363
427
|
def guppi_from_delta(
|
|
@@ -10,20 +10,18 @@ from __future__ import annotations
|
|
|
10
10
|
|
|
11
11
|
import concurrent.futures
|
|
12
12
|
from collections.abc import Sequence
|
|
13
|
-
from multiprocessing import cpu_count
|
|
14
13
|
from typing import Literal
|
|
15
14
|
|
|
16
15
|
import numpy as np
|
|
17
|
-
from attrs import Attribute, define, field
|
|
16
|
+
from attrs import Attribute, Converter, define, field
|
|
18
17
|
from numpy.random import PCG64DXSM, Generator, SeedSequence
|
|
19
18
|
|
|
20
|
-
from .. import VERSION, ArrayDouble # noqa: TID252
|
|
19
|
+
from .. import NTHREADS, VERSION, ArrayDouble, ArrayFloat # noqa: TID252
|
|
21
20
|
|
|
22
21
|
__version__ = VERSION
|
|
23
22
|
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
DEFAULT_BETA_DIST_PARMS: ArrayDouble = np.array([1.0, 1.0], float)
|
|
23
|
+
DEFAULT_DIST_PARMS: ArrayFloat = np.array([0.0, 1.0], float)
|
|
24
|
+
DEFAULT_BETA_DIST_PARMS: ArrayFloat = np.array([1.0, 1.0], float)
|
|
27
25
|
|
|
28
26
|
|
|
29
27
|
def prng(_s: SeedSequence | None = None, /) -> np.random.Generator:
|
|
@@ -110,6 +108,20 @@ def gen_seed_seq_list_default(
|
|
|
110
108
|
return [SeedSequence(_s, pool_size=8) for _s in generated_entropy[:_sseq_list_len]]
|
|
111
109
|
|
|
112
110
|
|
|
111
|
+
def _dist_parms_conv(_v: ArrayFloat, _i: MultithreadedRNG) -> ArrayFloat:
|
|
112
|
+
if not len(_v):
|
|
113
|
+
return {
|
|
114
|
+
"Beta": DEFAULT_BETA_DIST_PARMS,
|
|
115
|
+
"Dirichlet": np.ones(_i.values.shape[-1], float),
|
|
116
|
+
}.get(_i.dist_type, DEFAULT_DIST_PARMS)
|
|
117
|
+
elif isinstance(_v, Sequence | np.ndarray):
|
|
118
|
+
return np.asarray(_v, float)
|
|
119
|
+
else:
|
|
120
|
+
raise ValueError(
|
|
121
|
+
"Input, {_v!r} has invalid type. Must be None, Sequence of floats, or Numpy ndarray."
|
|
122
|
+
)
|
|
123
|
+
|
|
124
|
+
|
|
113
125
|
@define
|
|
114
126
|
class MultithreadedRNG:
|
|
115
127
|
"""Fill given array on demand with pseudo-random numbers as specified.
|
|
@@ -121,22 +133,32 @@ class MultithreadedRNG:
|
|
|
121
133
|
before commencing multithreaded random number generation.
|
|
122
134
|
"""
|
|
123
135
|
|
|
124
|
-
values: ArrayDouble = field(kw_only=False
|
|
136
|
+
values: ArrayDouble = field(kw_only=False)
|
|
125
137
|
"""Output array to which generated data are over-written
|
|
126
138
|
|
|
127
139
|
Array-length defines the number of i.i.d. (vector) draws.
|
|
128
140
|
"""
|
|
129
141
|
|
|
142
|
+
@values.validator
|
|
143
|
+
def _vsv(
|
|
144
|
+
_instance: MultithreadedRNG,
|
|
145
|
+
_attribute: Attribute[ArrayDouble],
|
|
146
|
+
_value: ArrayDouble,
|
|
147
|
+
/,
|
|
148
|
+
) -> None:
|
|
149
|
+
if not len(_value):
|
|
150
|
+
raise ValueError("Output array must at least be one dimension")
|
|
151
|
+
|
|
130
152
|
dist_type: Literal[
|
|
131
153
|
"Beta", "Dirichlet", "Gaussian", "Normal", "Random", "Uniform"
|
|
132
|
-
] = field(
|
|
154
|
+
] = field(default="Uniform")
|
|
133
155
|
"""Distribution for the generated random numbers.
|
|
134
156
|
|
|
135
157
|
Default is "Uniform".
|
|
136
158
|
"""
|
|
137
159
|
|
|
138
160
|
@dist_type.validator
|
|
139
|
-
def
|
|
161
|
+
def _dtv(
|
|
140
162
|
_instance: MultithreadedRNG, _attribute: Attribute[str], _value: str, /
|
|
141
163
|
) -> None:
|
|
142
164
|
if _value not in (
|
|
@@ -144,60 +166,48 @@ class MultithreadedRNG:
|
|
|
144
166
|
):
|
|
145
167
|
raise ValueError(f"Specified distribution must be one of {_rdts}")
|
|
146
168
|
|
|
147
|
-
dist_parms:
|
|
169
|
+
dist_parms: ArrayFloat = field(
|
|
170
|
+
converter=Converter(_dist_parms_conv, takes_self=True) # type: ignore
|
|
171
|
+
)
|
|
148
172
|
"""Parameters, if any, for tailoring random number generation
|
|
149
173
|
"""
|
|
150
174
|
|
|
175
|
+
@dist_parms.default
|
|
176
|
+
def _dpd(_instance: MultithreadedRNG) -> ArrayFloat:
|
|
177
|
+
return {
|
|
178
|
+
"Beta": DEFAULT_BETA_DIST_PARMS,
|
|
179
|
+
"Dirichlet": np.ones(_instance.values.shape[-1], float),
|
|
180
|
+
}.get(_instance.dist_type, DEFAULT_DIST_PARMS)
|
|
181
|
+
|
|
151
182
|
@dist_parms.validator
|
|
152
|
-
def
|
|
153
|
-
_instance: MultithreadedRNG,
|
|
183
|
+
def _dpv(
|
|
184
|
+
_instance: MultithreadedRNG,
|
|
185
|
+
_attribute: Attribute[ArrayFloat],
|
|
186
|
+
_value: ArrayFloat,
|
|
187
|
+
/,
|
|
154
188
|
) -> None:
|
|
155
|
-
if
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
189
|
+
if (
|
|
190
|
+
_instance.dist_type != "Dirichlet"
|
|
191
|
+
and (_lrdp := len(_value)) != (_trdp := 2)
|
|
192
|
+
) or (
|
|
193
|
+
_instance.dist_type == "Dirichlet"
|
|
194
|
+
and (_lrdp := len(_value)) != (_trdp := _instance.values.shape[1])
|
|
195
|
+
):
|
|
196
|
+
raise ValueError(f"Expected {_trdp} parameters, got, {_lrdp}")
|
|
160
197
|
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
and (_lrdp := len(_value)) != (_trdp := 2)
|
|
164
|
-
) or (
|
|
165
|
-
_instance.dist_type == "Dirichlet"
|
|
166
|
-
and (_lrdp := len(_value)) != (_trdp := _instance.values.shape[1])
|
|
167
|
-
):
|
|
168
|
-
raise ValueError(f"Expected {_trdp} parameters, got, {_lrdp}")
|
|
169
|
-
|
|
170
|
-
elif (
|
|
171
|
-
_instance.dist_type in ("Beta", "Dirichlet")
|
|
172
|
-
and (np.array(_value) <= 0.0).any()
|
|
173
|
-
):
|
|
174
|
-
raise ValueError(
|
|
175
|
-
"Shape and location parameters must be strictly positive"
|
|
176
|
-
)
|
|
198
|
+
elif _instance.dist_type in ("Beta", "Dirichlet") and (_value <= 0.0).any():
|
|
199
|
+
raise ValueError("Shape and location parameters must be strictly positive")
|
|
177
200
|
|
|
178
|
-
seed_sequence: SeedSequence | None = field(
|
|
201
|
+
seed_sequence: SeedSequence | None = field(default=None)
|
|
179
202
|
"""Seed sequence for generating random numbers."""
|
|
180
203
|
|
|
181
|
-
nthreads: int = field(
|
|
204
|
+
nthreads: int = field(default=NTHREADS)
|
|
182
205
|
"""Number of threads to spawn for random number generation."""
|
|
183
206
|
|
|
184
207
|
def fill(self) -> None:
|
|
185
208
|
"""Fill the provided output array with random number draws as specified."""
|
|
186
209
|
|
|
187
|
-
if (
|
|
188
|
-
self.dist_parms is None
|
|
189
|
-
or not (
|
|
190
|
-
_dist_parms := np.array(self.dist_parms) # one-shot conversion
|
|
191
|
-
).any()
|
|
192
|
-
):
|
|
193
|
-
if self.dist_type == "Beta":
|
|
194
|
-
_dist_parms = DEFAULT_BETA_DIST_PARMS
|
|
195
|
-
elif self.dist_type == "Dirichlet":
|
|
196
|
-
_dist_parms = np.ones(self.values.shape[1], float)
|
|
197
|
-
else:
|
|
198
|
-
_dist_parms = DEFAULT_DIST_PARMS
|
|
199
|
-
|
|
200
|
-
if self.dist_parms is None or np.array_equal(
|
|
210
|
+
if not len(self.dist_parms) or np.array_equal(
|
|
201
211
|
self.dist_parms, DEFAULT_DIST_PARMS
|
|
202
212
|
):
|
|
203
213
|
if self.dist_type == "Uniform":
|
|
@@ -219,7 +229,7 @@ class MultithreadedRNG:
|
|
|
219
229
|
def _fill(
|
|
220
230
|
_rng: np.random.Generator,
|
|
221
231
|
_dist_type: str,
|
|
222
|
-
_dist_parms:
|
|
232
|
+
_dist_parms: ArrayFloat,
|
|
223
233
|
_out: ArrayDouble,
|
|
224
234
|
_first: int,
|
|
225
235
|
_last: int,
|
|
@@ -254,7 +264,7 @@ class MultithreadedRNG:
|
|
|
254
264
|
_fill,
|
|
255
265
|
_random_generators[i],
|
|
256
266
|
_dist_type,
|
|
257
|
-
|
|
267
|
+
self.dist_parms,
|
|
258
268
|
self.values,
|
|
259
269
|
_range_first,
|
|
260
270
|
_range_last,
|