mergeron 2025.739265.2__tar.gz → 2025.739290.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mergeron might be problematic. Click here for more details.

Files changed (24)
  1. {mergeron-2025.739265.2 → mergeron-2025.739290.1}/PKG-INFO +2 -1
  2. {mergeron-2025.739265.2 → mergeron-2025.739290.1}/pyproject.toml +2 -1
  3. mergeron-2025.739290.1/src/mergeron/__init__.py +130 -0
  4. {mergeron-2025.739265.2 → mergeron-2025.739290.1}/src/mergeron/core/guidelines_boundaries.py +85 -21
  5. {mergeron-2025.739265.2 → mergeron-2025.739290.1}/src/mergeron/core/pseudorandom_numbers.py +61 -51
  6. {mergeron-2025.739265.2 → mergeron-2025.739290.1}/src/mergeron/gen/__init__.py +222 -84
  7. {mergeron-2025.739265.2 → mergeron-2025.739290.1}/src/mergeron/gen/data_generation.py +143 -182
  8. {mergeron-2025.739265.2 → mergeron-2025.739290.1}/src/mergeron/gen/data_generation_functions.py +68 -118
  9. {mergeron-2025.739265.2 → mergeron-2025.739290.1}/src/mergeron/gen/enforcement_stats.py +30 -6
  10. {mergeron-2025.739265.2 → mergeron-2025.739290.1}/src/mergeron/gen/upp_tests.py +6 -7
  11. mergeron-2025.739265.2/src/mergeron/__init__.py +0 -65
  12. {mergeron-2025.739265.2 → mergeron-2025.739290.1}/README.rst +0 -0
  13. {mergeron-2025.739265.2 → mergeron-2025.739290.1}/src/mergeron/core/__init__.py +0 -0
  14. {mergeron-2025.739265.2 → mergeron-2025.739290.1}/src/mergeron/core/empirical_margin_distribution.py +0 -0
  15. {mergeron-2025.739265.2 → mergeron-2025.739290.1}/src/mergeron/core/ftc_merger_investigations_data.py +0 -0
  16. {mergeron-2025.739265.2 → mergeron-2025.739290.1}/src/mergeron/core/guidelines_boundary_functions.py +0 -0
  17. {mergeron-2025.739265.2 → mergeron-2025.739290.1}/src/mergeron/core/guidelines_boundary_functions_extra.py +0 -0
  18. {mergeron-2025.739265.2 → mergeron-2025.739290.1}/src/mergeron/data/__init__.py +0 -0
  19. {mergeron-2025.739265.2 → mergeron-2025.739290.1}/src/mergeron/data/damodaran_margin_data.xls +0 -0
  20. {mergeron-2025.739265.2 → mergeron-2025.739290.1}/src/mergeron/data/damodaran_margin_data_dict.msgpack +0 -0
  21. {mergeron-2025.739265.2 → mergeron-2025.739290.1}/src/mergeron/data/ftc_invdata.msgpack +0 -0
  22. {mergeron-2025.739265.2 → mergeron-2025.739290.1}/src/mergeron/demo/__init__.py +0 -0
  23. {mergeron-2025.739265.2 → mergeron-2025.739290.1}/src/mergeron/demo/visualize_empirical_margin_distribution.py +0 -0
  24. {mergeron-2025.739265.2 → mergeron-2025.739290.1}/src/mergeron/py.typed +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: mergeron
3
- Version: 2025.739265.2
3
+ Version: 2025.739290.1
4
4
  Summary: Analyze merger enforcement policy using Python
5
5
  License: MIT
6
6
  Keywords: merger policy analysis,merger guidelines,merger screening,policy presumptions,concentration standards,upward pricing pressure,GUPPI
@@ -30,6 +30,7 @@ Requires-Dist: matplotlib (>=3.8)
30
30
  Requires-Dist: mpmath (>=1.3)
31
31
  Requires-Dist: msgpack (>=1.0)
32
32
  Requires-Dist: msgpack-numpy (>=0.4)
33
+ Requires-Dist: ruamel-yaml (>=0.18.10,<0.19.0)
33
34
  Requires-Dist: scipy (>=1.12)
34
35
  Requires-Dist: sympy (>=1.12)
35
36
  Requires-Dist: tables (>=3.10.1)
@@ -13,7 +13,7 @@ keywords = [
13
13
  "upward pricing pressure",
14
14
  "GUPPI",
15
15
  ]
16
- version = "2025.739265.2"
16
+ version = "2025.739290.1"
17
17
 
18
18
  # Classifiers list: https://pypi.org/classifiers/
19
19
  classifiers = [
@@ -57,6 +57,7 @@ certifi = ">=2023.11.17"
57
57
  types-beautifulsoup4 = ">=4.11.2"
58
58
  xlrd = "^2.0.1" # Needed to read margin data
59
59
  urllib3 = "^2.2.2"
60
+ ruamel-yaml = "^0.18.10"
60
61
 
61
62
 
62
63
  [tool.poetry.group.dev.dependencies]
@@ -0,0 +1,130 @@
1
+ from __future__ import annotations
2
+
3
+ import enum
4
+ from multiprocessing import cpu_count
5
+ from pathlib import Path
6
+ from typing import Literal
7
+
8
+ import numpy as np
9
+ from numpy.random import SeedSequence
10
+ from numpy.typing import NDArray
11
+ from ruamel import yaml
12
+
13
+ _PKG_NAME: str = Path(__file__).parent.stem
14
+
15
+ VERSION = "2025.739290.1"
16
+
17
+ __version__ = VERSION
18
+
19
+ this_yaml = yaml.YAML(typ="safe", pure=True)
20
+ this_yaml.constructor.deep_construct = True
21
+ this_yaml.indent(mapping=2, sequence=4, offset=2)
22
+
23
+
24
+ DATA_DIR: Path = Path.home() / _PKG_NAME
25
+ """
26
+ Defines a subdirectory named for this package in the user's home path.
27
+
28
+ If the subdirectory doesn't exist, it is created on package invocation.
29
+ """
30
+ if not DATA_DIR.is_dir():
31
+ DATA_DIR.mkdir(parents=False)
32
+
33
+ DEFAULT_REC_RATIO = 0.85
34
+
35
+ EMPTY_ARRAYDOUBLE = np.array([], float)
36
+ EMPTY_ARRAYINT = np.array([], int)
37
+
38
+ NTHREADS = 2 * cpu_count()
39
+
40
+ np.set_printoptions(precision=24, floatmode="fixed")
41
+
42
+ type HMGPubYear = Literal[1982, 1984, 1992, 2010, 2023]
43
+
44
+ type ArrayBoolean = NDArray[np.bool_]
45
+ type ArrayFloat = NDArray[np.float16 | np.float32 | np.float64 | np.float128]
46
+ type ArrayINT = NDArray[np.intp]
47
+
48
+ type ArrayDouble = NDArray[np.float64]
49
+ type ArrayBIGINT = NDArray[np.int64]
50
+
51
+ ## Add yaml representer, constructor for NoneType
52
+ (_, _) = (
53
+ this_yaml.representer.add_representer(
54
+ type(None), lambda _r, _d: _r.represent_scalar("!None", "none")
55
+ ),
56
+ this_yaml.constructor.add_constructor("!None", lambda _c, _n, /: None),
57
+ )
58
+
59
+ ## Add yaml representer, constructor for ndarray
60
+ (_, _) = (
61
+ this_yaml.representer.add_representer(
62
+ np.ndarray,
63
+ lambda _r, _d: _r.represent_sequence("!ndarray", (_d.tolist(), _d.dtype.str)),
64
+ ),
65
+ this_yaml.constructor.add_constructor(
66
+ "!ndarray", lambda _c, _n, /: np.array(*_c.construct_sequence(_n))
67
+ ),
68
+ )
69
+
70
+ # Add yaml representer, constructor for SeedSequence
71
+ this_yaml.representer.add_representer(
72
+ SeedSequence,
73
+ lambda _r, _d: _r.represent_mapping(
74
+ "!SeedSequence",
75
+ {
76
+ _a: getattr(_d, _a)
77
+ for _a in ("entropy", "spawn_key", "pool_size", "n_children_spawned")
78
+ },
79
+ ),
80
+ )
81
+ this_yaml.constructor.add_constructor(
82
+ "!SeedSequence", lambda _c, _n, /: SeedSequence(**_c.construct_mapping(_n))
83
+ )
84
+
85
+
86
+ @enum.unique
87
+ class RECForm(enum.StrEnum):
88
+ """For derivation of recapture ratio from market shares."""
89
+
90
+ INOUT = "inside-out"
91
+ OUTIN = "outside-in"
92
+ FIXED = "proportional"
93
+
94
+
95
+ @enum.unique
96
+ class UPPAggrSelector(enum.StrEnum):
97
+ """
98
+ Aggregator for GUPPI and diversion ratio estimates.
99
+
100
+ """
101
+
102
+ AVG = "average"
103
+ CPA = "cross-product-share weighted average"
104
+ CPD = "cross-product-share weighted distance"
105
+ CPG = "cross-product-share weighted geometric mean"
106
+ DIS = "symmetrically-weighted distance"
107
+ GMN = "geometric mean"
108
+ MAX = "max"
109
+ MIN = "min"
110
+ OSA = "own-share weighted average"
111
+ OSD = "own-share weighted distance"
112
+ OSG = "own-share weighted geometric mean"
113
+
114
+
115
+ for _typ in (RECForm, UPPAggrSelector):
116
+ # NOTE: If additional enums are defined in this module,
117
+ # add them to the list above
118
+
119
+ _, _ = (
120
+ this_yaml.representer.add_representer(
121
+ _typ,
122
+ lambda _r, _d: _r.represent_scalar(f"!{_d.__class__.__name__}", _d.name),
123
+ ),
124
+ this_yaml.constructor.add_constructor(
125
+ f"!{_typ.__name__}",
126
+ lambda _c, _n, /: getattr(
127
+ globals().get(_n.tag.lstrip("!")), _c.construct_scalar(_n)
128
+ ),
129
+ ),
130
+ )
@@ -13,6 +13,7 @@ from typing import Literal
13
13
  import numpy as np
14
14
  from attrs import Attribute, field, frozen, validators
15
15
  from mpmath import mp, mpf # type: ignore
16
+ from ruamel import yaml
16
17
 
17
18
  from .. import ( # noqa: TID252
18
19
  DEFAULT_REC_RATIO,
@@ -21,6 +22,7 @@ from .. import ( # noqa: TID252
21
22
  HMGPubYear,
22
23
  RECForm,
23
24
  UPPAggrSelector,
25
+ this_yaml,
24
26
  )
25
27
  from . import guidelines_boundary_functions as gbfn
26
28
 
@@ -42,6 +44,7 @@ class HMGThresholds:
42
44
  ipr: float
43
45
 
44
46
 
47
+ @this_yaml.register_class
45
48
  @frozen
46
49
  class GuidelinesThresholds:
47
50
  """
@@ -62,7 +65,7 @@ class GuidelinesThresholds:
62
65
  Year of publication of the Guidelines
63
66
  """
64
67
 
65
- safeharbor: HMGThresholds = field(kw_only=True, default=None)
68
+ safeharbor: HMGThresholds = field(kw_only=True, default=None, init=False)
66
69
  """
67
70
  Negative presumption quantified on various measures
68
71
 
@@ -70,7 +73,7 @@ class GuidelinesThresholds:
70
73
  diversion ratio limit, CMCR, and IPR
71
74
  """
72
75
 
73
- presumption: HMGThresholds = field(kw_only=True, default=None)
76
+ presumption: HMGThresholds = field(kw_only=True, default=None, init=False)
74
77
  """
75
78
  Presumption of harm defined in HMG
76
79
 
@@ -78,7 +81,7 @@ class GuidelinesThresholds:
78
81
  diversion ratio limit, CMCR, and IPR
79
82
  """
80
83
 
81
- imputed_presumption: HMGThresholds = field(kw_only=True, default=None)
84
+ imputed_presumption: HMGThresholds = field(kw_only=True, default=None, init=False)
82
85
  """
83
86
  Presumption of harm imputed from guidelines
84
87
 
@@ -147,31 +150,51 @@ class GuidelinesThresholds:
147
150
  ),
148
151
  )
149
152
 
153
+ @classmethod
154
+ def to_yaml(
155
+ cls, _r: yaml.representer.SafeRepresenter, _d: GuidelinesThresholds
156
+ ) -> yaml.MappingNode:
157
+ _ret: yaml.MappingNode = _r.represent_mapping(
158
+ f"!{cls.__name__}",
159
+ {_a.name: getattr(_d, _a.name) for _a in _d.__attrs_attrs__},
160
+ )
161
+ return _ret
162
+
163
+ @classmethod
164
+ def from_yaml(
165
+ cls, _c: yaml.constructor.SafeConstructor, _n: yaml.MappingNode
166
+ ) -> GuidelinesThresholds:
167
+ return cls(**_c.construct_mapping(_n))
150
168
 
169
+
170
+ @this_yaml.register_class
151
171
  @frozen
152
172
  class ConcentrationBoundary:
153
173
  """Concentration parameters, boundary coordinates, and area under concentration boundary."""
154
174
 
155
175
  measure_name: Literal[
156
- "ΔHHI", "Combined share", "Pre-merger HHI", "Post-merger HHI"
176
+ "ΔHHI",
177
+ "Combined share",
178
+ "Pre-merger HHI Contribution",
179
+ "Post-merger HHI Contribution",
157
180
  ] = field(kw_only=False, default="ΔHHI")
158
181
 
159
182
  @measure_name.validator
160
- def __mnv(
183
+ def _mnv(
161
184
  _instance: ConcentrationBoundary, _attribute: Attribute[str], _value: str, /
162
185
  ) -> None:
163
186
  if _value not in (
164
187
  "ΔHHI",
165
188
  "Combined share",
166
- "Pre-merger HHI",
167
- "Post-merger HHI",
189
+ "Pre-merger HHI Contribution",
190
+ "Post-merger HHI Contribution",
168
191
  ):
169
192
  raise ValueError(f"Invalid name for a concentration measure, {_value!r}.")
170
193
 
171
194
  threshold: float = field(kw_only=False, default=0.01)
172
195
 
173
196
  @threshold.validator
174
- def __tv(
197
+ def _tv(
175
198
  _instance: ConcentrationBoundary, _attribute: Attribute[float], _value: float, /
176
199
  ) -> None:
177
200
  if not 0 <= _value <= 1:
@@ -181,28 +204,49 @@ class ConcentrationBoundary:
181
204
  kw_only=False, default=5, validator=validators.instance_of(int)
182
205
  )
183
206
 
184
- coordinates: ArrayDouble = field(init=False, kw_only=True)
185
- """Market-share pairs as Cartesian coordinates of points on the concentration boundary."""
186
-
187
207
  area: float = field(init=False, kw_only=True)
188
208
  """Area under the concentration boundary."""
189
209
 
210
+ coordinates: ArrayDouble = field(init=False, kw_only=True)
211
+ """Market-share pairs as Cartesian coordinates of points on the concentration boundary."""
212
+
190
213
  def __attrs_post_init__(self, /) -> None:
191
214
  match self.measure_name:
192
215
  case "ΔHHI":
193
216
  _conc_fn = gbfn.hhi_delta_boundary
194
217
  case "Combined share":
195
218
  _conc_fn = gbfn.combined_share_boundary
196
- case "Pre-merger HHI":
219
+ case "Pre-merger HHI Contribution":
197
220
  _conc_fn = gbfn.hhi_pre_contrib_boundary
198
- case "Post-merger HHI":
221
+ case "Post-merger HHI Contribution":
199
222
  _conc_fn = gbfn.hhi_post_contrib_boundary
200
223
 
201
224
  _boundary = _conc_fn(self.threshold, dps=self.precision)
202
- object.__setattr__(self, "coordinates", _boundary.coordinates)
203
225
  object.__setattr__(self, "area", _boundary.area)
226
+ object.__setattr__(self, "coordinates", _boundary.coordinates)
227
+
228
+ @classmethod
229
+ def to_yaml(
230
+ cls, _r: yaml.representer.SafeRepresenter, _d: ConcentrationBoundary
231
+ ) -> yaml.MappingNode:
232
+ _ret: yaml.MappingNode = _r.represent_mapping(
233
+ f"!{cls.__name__}",
234
+ {
235
+ _a.name: getattr(_d, _a.name)
236
+ for _a in _d.__attrs_attrs__
237
+ if _a.name not in ("area", "coordinates")
238
+ },
239
+ )
240
+ return _ret
204
241
 
242
+ @classmethod
243
+ def from_yaml(
244
+ cls, _c: yaml.constructor.SafeConstructor, _n: yaml.MappingNode
245
+ ) -> ConcentrationBoundary:
246
+ return cls(**_c.construct_mapping(_n))
205
247
 
248
+
249
+ @this_yaml.register_class
206
250
  @frozen
207
251
  class DiversionRatioBoundary:
208
252
  """
@@ -221,13 +265,13 @@ class DiversionRatioBoundary:
221
265
  diversion_ratio: float = field(kw_only=False, default=0.065)
222
266
 
223
267
  @diversion_ratio.validator
224
- def __dvv(
268
+ def _dvv(
225
269
  _instance: DiversionRatioBoundary,
226
270
  _attribute: Attribute[float],
227
271
  _value: float,
228
272
  /,
229
273
  ) -> None:
230
- if not (isinstance(_value, float) and 0 <= _value <= 1):
274
+ if not (isinstance(_value, decimal.Decimal | float) and 0 <= _value <= 1):
231
275
  raise ValueError(
232
276
  "Margin-adjusted benchmark share ratio must lie between 0 and 1."
233
277
  )
@@ -260,7 +304,7 @@ class DiversionRatioBoundary:
260
304
  """
261
305
 
262
306
  @recapture_form.validator
263
- def __rsv(
307
+ def _rsv(
264
308
  _instance: DiversionRatioBoundary,
265
309
  _attribute: Attribute[RECForm],
266
310
  _value: RECForm,
@@ -307,12 +351,12 @@ class DiversionRatioBoundary:
307
351
 
308
352
  """
309
353
 
310
- coordinates: ArrayDouble = field(init=False, kw_only=True)
311
- """Market-share pairs as Cartesian coordinates of points on the diversion ratio boundary."""
312
-
313
354
  area: float = field(init=False, kw_only=True)
314
355
  """Area under the diversion ratio boundary."""
315
356
 
357
+ coordinates: ArrayDouble = field(init=False, kw_only=True)
358
+ """Market-share pairs as Cartesian coordinates of points on the diversion ratio boundary."""
359
+
316
360
  def __attrs_post_init__(self, /) -> None:
317
361
  _share_ratio = critical_share_ratio(
318
362
  self.diversion_ratio, r_bar=self.recapture_ratio
@@ -356,8 +400,28 @@ class DiversionRatioBoundary:
356
400
  _upp_agg_kwargs |= {"agg_method": _aggregator, "weighting": _wgt_type}
357
401
 
358
402
  _boundary = _upp_agg_fn(_share_ratio, self.recapture_ratio, **_upp_agg_kwargs)
359
- object.__setattr__(self, "coordinates", _boundary.coordinates)
360
403
  object.__setattr__(self, "area", _boundary.area)
404
+ object.__setattr__(self, "coordinates", _boundary.coordinates)
405
+
406
+ @classmethod
407
+ def to_yaml(
408
+ cls, _r: yaml.representer.SafeRepresenter, _d: DiversionRatioBoundary
409
+ ) -> yaml.MappingNode:
410
+ _ret: yaml.MappingNode = _r.represent_mapping(
411
+ f"!{cls.__name__}",
412
+ {
413
+ _a.name: getattr(_d, _a.name)
414
+ for _a in _d.__attrs_attrs__
415
+ if _a.name not in ("area", "coordinates")
416
+ },
417
+ )
418
+ return _ret
419
+
420
+ @classmethod
421
+ def from_yaml(
422
+ cls, _c: yaml.constructor.SafeConstructor, _n: yaml.MappingNode
423
+ ) -> DiversionRatioBoundary:
424
+ return cls(**_c.construct_mapping(_n))
361
425
 
362
426
 
363
427
  def guppi_from_delta(
@@ -10,20 +10,18 @@ from __future__ import annotations
10
10
 
11
11
  import concurrent.futures
12
12
  from collections.abc import Sequence
13
- from multiprocessing import cpu_count
14
13
  from typing import Literal
15
14
 
16
15
  import numpy as np
17
- from attrs import Attribute, define, field
16
+ from attrs import Attribute, Converter, define, field
18
17
  from numpy.random import PCG64DXSM, Generator, SeedSequence
19
18
 
20
- from .. import VERSION, ArrayDouble # noqa: TID252
19
+ from .. import NTHREADS, VERSION, ArrayDouble, ArrayFloat # noqa: TID252
21
20
 
22
21
  __version__ = VERSION
23
22
 
24
- NTHREADS = 2 * cpu_count()
25
- DEFAULT_DIST_PARMS: ArrayDouble = np.array([0.0, 1.0], float)
26
- DEFAULT_BETA_DIST_PARMS: ArrayDouble = np.array([1.0, 1.0], float)
23
+ DEFAULT_DIST_PARMS: ArrayFloat = np.array([0.0, 1.0], float)
24
+ DEFAULT_BETA_DIST_PARMS: ArrayFloat = np.array([1.0, 1.0], float)
27
25
 
28
26
 
29
27
  def prng(_s: SeedSequence | None = None, /) -> np.random.Generator:
@@ -110,6 +108,20 @@ def gen_seed_seq_list_default(
110
108
  return [SeedSequence(_s, pool_size=8) for _s in generated_entropy[:_sseq_list_len]]
111
109
 
112
110
 
111
+ def _dist_parms_conv(_v: ArrayFloat, _i: MultithreadedRNG) -> ArrayFloat:
112
+ if not len(_v):
113
+ return {
114
+ "Beta": DEFAULT_BETA_DIST_PARMS,
115
+ "Dirichlet": np.ones(_i.values.shape[-1], float),
116
+ }.get(_i.dist_type, DEFAULT_DIST_PARMS)
117
+ elif isinstance(_v, Sequence | np.ndarray):
118
+ return np.asarray(_v, float)
119
+ else:
120
+ raise ValueError(
121
+ "Input, {_v!r} has invalid type. Must be None, Sequence of floats, or Numpy ndarray."
122
+ )
123
+
124
+
113
125
  @define
114
126
  class MultithreadedRNG:
115
127
  """Fill given array on demand with pseudo-random numbers as specified.
@@ -121,22 +133,32 @@ class MultithreadedRNG:
121
133
  before commencing multithreaded random number generation.
122
134
  """
123
135
 
124
- values: ArrayDouble = field(kw_only=False, default=None)
136
+ values: ArrayDouble = field(kw_only=False)
125
137
  """Output array to which generated data are over-written
126
138
 
127
139
  Array-length defines the number of i.i.d. (vector) draws.
128
140
  """
129
141
 
142
+ @values.validator
143
+ def _vsv(
144
+ _instance: MultithreadedRNG,
145
+ _attribute: Attribute[ArrayDouble],
146
+ _value: ArrayDouble,
147
+ /,
148
+ ) -> None:
149
+ if not len(_value):
150
+ raise ValueError("Output array must at least be one dimension")
151
+
130
152
  dist_type: Literal[
131
153
  "Beta", "Dirichlet", "Gaussian", "Normal", "Random", "Uniform"
132
- ] = field(kw_only=True, default="Uniform")
154
+ ] = field(default="Uniform")
133
155
  """Distribution for the generated random numbers.
134
156
 
135
157
  Default is "Uniform".
136
158
  """
137
159
 
138
160
  @dist_type.validator
139
- def __dtv(
161
+ def _dtv(
140
162
  _instance: MultithreadedRNG, _attribute: Attribute[str], _value: str, /
141
163
  ) -> None:
142
164
  if _value not in (
@@ -144,60 +166,48 @@ class MultithreadedRNG:
144
166
  ):
145
167
  raise ValueError(f"Specified distribution must be one of {_rdts}")
146
168
 
147
- dist_parms: ArrayDouble | None = field(kw_only=True, default=DEFAULT_DIST_PARMS)
169
+ dist_parms: ArrayFloat = field(
170
+ converter=Converter(_dist_parms_conv, takes_self=True) # type: ignore
171
+ )
148
172
  """Parameters, if any, for tailoring random number generation
149
173
  """
150
174
 
175
+ @dist_parms.default
176
+ def _dpd(_instance: MultithreadedRNG) -> ArrayFloat:
177
+ return {
178
+ "Beta": DEFAULT_BETA_DIST_PARMS,
179
+ "Dirichlet": np.ones(_instance.values.shape[-1], float),
180
+ }.get(_instance.dist_type, DEFAULT_DIST_PARMS)
181
+
151
182
  @dist_parms.validator
152
- def __dpv(
153
- _instance: MultithreadedRNG, _attribute: Attribute[str], _value: ArrayDouble, /
183
+ def _dpv(
184
+ _instance: MultithreadedRNG,
185
+ _attribute: Attribute[ArrayFloat],
186
+ _value: ArrayFloat,
187
+ /,
154
188
  ) -> None:
155
- if _value is not None:
156
- if not isinstance(_value, Sequence | np.ndarray):
157
- raise ValueError(
158
- "When specified, distribution parameters must be a list, tuple or Numpy array"
159
- )
189
+ if (
190
+ _instance.dist_type != "Dirichlet"
191
+ and (_lrdp := len(_value)) != (_trdp := 2)
192
+ ) or (
193
+ _instance.dist_type == "Dirichlet"
194
+ and (_lrdp := len(_value)) != (_trdp := _instance.values.shape[1])
195
+ ):
196
+ raise ValueError(f"Expected {_trdp} parameters, got, {_lrdp}")
160
197
 
161
- elif (
162
- _instance.dist_type != "Dirichlet"
163
- and (_lrdp := len(_value)) != (_trdp := 2)
164
- ) or (
165
- _instance.dist_type == "Dirichlet"
166
- and (_lrdp := len(_value)) != (_trdp := _instance.values.shape[1])
167
- ):
168
- raise ValueError(f"Expected {_trdp} parameters, got, {_lrdp}")
169
-
170
- elif (
171
- _instance.dist_type in ("Beta", "Dirichlet")
172
- and (np.array(_value) <= 0.0).any()
173
- ):
174
- raise ValueError(
175
- "Shape and location parameters must be strictly positive"
176
- )
198
+ elif _instance.dist_type in ("Beta", "Dirichlet") and (_value <= 0.0).any():
199
+ raise ValueError("Shape and location parameters must be strictly positive")
177
200
 
178
- seed_sequence: SeedSequence | None = field(kw_only=True, default=None)
201
+ seed_sequence: SeedSequence | None = field(default=None)
179
202
  """Seed sequence for generating random numbers."""
180
203
 
181
- nthreads: int = field(kw_only=True, default=NTHREADS)
204
+ nthreads: int = field(default=NTHREADS)
182
205
  """Number of threads to spawn for random number generation."""
183
206
 
184
207
  def fill(self) -> None:
185
208
  """Fill the provided output array with random number draws as specified."""
186
209
 
187
- if (
188
- self.dist_parms is None
189
- or not (
190
- _dist_parms := np.array(self.dist_parms) # one-shot conversion
191
- ).any()
192
- ):
193
- if self.dist_type == "Beta":
194
- _dist_parms = DEFAULT_BETA_DIST_PARMS
195
- elif self.dist_type == "Dirichlet":
196
- _dist_parms = np.ones(self.values.shape[1], float)
197
- else:
198
- _dist_parms = DEFAULT_DIST_PARMS
199
-
200
- if self.dist_parms is None or np.array_equal(
210
+ if not len(self.dist_parms) or np.array_equal(
201
211
  self.dist_parms, DEFAULT_DIST_PARMS
202
212
  ):
203
213
  if self.dist_type == "Uniform":
@@ -219,7 +229,7 @@ class MultithreadedRNG:
219
229
  def _fill(
220
230
  _rng: np.random.Generator,
221
231
  _dist_type: str,
222
- _dist_parms: ArrayDouble,
232
+ _dist_parms: ArrayFloat,
223
233
  _out: ArrayDouble,
224
234
  _first: int,
225
235
  _last: int,
@@ -254,7 +264,7 @@ class MultithreadedRNG:
254
264
  _fill,
255
265
  _random_generators[i],
256
266
  _dist_type,
257
- _dist_parms,
267
+ self.dist_parms,
258
268
  self.values,
259
269
  _range_first,
260
270
  _range_last,