gundam-interface 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gundam_interface/__init__.py +38 -0
- gundam_interface/_version.py +8 -0
- gundam_interface/config.py +145 -0
- gundam_interface/interface.py +347 -0
- gundam_interface/logging.py +148 -0
- gundam_interface/parameters.py +99 -0
- gundam_interface/py.typed +1 -0
- gundam_interface-0.1.0.dist-info/METADATA +70 -0
- gundam_interface-0.1.0.dist-info/RECORD +12 -0
- gundam_interface-0.1.0.dist-info/WHEEL +5 -0
- gundam_interface-0.1.0.dist-info/licenses/LICENCE +503 -0
- gundam_interface-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
"""Public package interface for gundam-interface."""
|
|
2
|
+
|
|
3
|
+
from ._version import __version__
|
|
4
|
+
from .config import GundamContext
|
|
5
|
+
from .interface import GundamInterface, PostfitThrowSamples
|
|
6
|
+
from .logging import (
|
|
7
|
+
isNotebookRuntime,
|
|
8
|
+
maybeRedirectNativeOutput,
|
|
9
|
+
redirectNativeOutput,
|
|
10
|
+
temporaryRedirectNativeOutput,
|
|
11
|
+
)
|
|
12
|
+
from .parameters import (
|
|
13
|
+
GundamParameter,
|
|
14
|
+
collectActiveParameters,
|
|
15
|
+
normalizedToPhysical,
|
|
16
|
+
parameterPriors,
|
|
17
|
+
parameterSteps,
|
|
18
|
+
parameterThrowValues,
|
|
19
|
+
physicalToNormalized,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
__all__ = [
|
|
23
|
+
"__version__",
|
|
24
|
+
"GundamContext",
|
|
25
|
+
"GundamInterface",
|
|
26
|
+
"GundamParameter",
|
|
27
|
+
"PostfitThrowSamples",
|
|
28
|
+
"collectActiveParameters",
|
|
29
|
+
"isNotebookRuntime",
|
|
30
|
+
"maybeRedirectNativeOutput",
|
|
31
|
+
"normalizedToPhysical",
|
|
32
|
+
"parameterPriors",
|
|
33
|
+
"parameterSteps",
|
|
34
|
+
"parameterThrowValues",
|
|
35
|
+
"physicalToNormalized",
|
|
36
|
+
"redirectNativeOutput",
|
|
37
|
+
"temporaryRedirectNativeOutput",
|
|
38
|
+
]
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass(slots=True)
|
|
10
|
+
class GundamContext:
|
|
11
|
+
"""Runtime context needed to construct the GUNDAM Python interface."""
|
|
12
|
+
|
|
13
|
+
nCpuThreads: int
|
|
14
|
+
pythonPath: str | Path
|
|
15
|
+
workDir: str | Path
|
|
16
|
+
configPath: str | Path | None = None
|
|
17
|
+
overrideList: list[str | Path] = field(default_factory=list)
|
|
18
|
+
configJsonString: str | None = None
|
|
19
|
+
forceAsimov: bool | None = None
|
|
20
|
+
dataType: str | None = None
|
|
21
|
+
randomSeed: int | None = None
|
|
22
|
+
|
|
23
|
+
def __post_init__(self) -> None:
|
|
24
|
+
self.pythonPath = Path(self.pythonPath).expanduser()
|
|
25
|
+
self.workDir = Path(self.workDir).expanduser()
|
|
26
|
+
if self.configPath is not None:
|
|
27
|
+
self.configPath = Path(self.configPath).expanduser()
|
|
28
|
+
self.overrideList = [Path(path).expanduser() for path in self.overrideList]
|
|
29
|
+
if self.configJsonString is not None:
|
|
30
|
+
self.configJsonString = self.configJsonString.strip()
|
|
31
|
+
|
|
32
|
+
if self.nCpuThreads < 1:
|
|
33
|
+
raise ValueError("nCpuThreads must be >= 1")
|
|
34
|
+
if self.randomSeed is not None:
|
|
35
|
+
self.randomSeed = int(self.randomSeed)
|
|
36
|
+
if self.randomSeed < 0:
|
|
37
|
+
raise ValueError("randomSeed must be >= 0")
|
|
38
|
+
if self.configPath is None and self.configJsonString is None:
|
|
39
|
+
raise ValueError("Either configPath or configJsonString must be provided")
|
|
40
|
+
self.dataType = self._canonicalDataType(self.dataType, self.forceAsimov)
|
|
41
|
+
|
|
42
|
+
@classmethod
|
|
43
|
+
def fromDict(cls, data: dict[str, Any]) -> "GundamContext":
|
|
44
|
+
return cls(
|
|
45
|
+
nCpuThreads=int(data["nCpuThreads"]),
|
|
46
|
+
pythonPath=data["pythonPath"],
|
|
47
|
+
workDir=data["workDir"],
|
|
48
|
+
configPath=data.get("configPath"),
|
|
49
|
+
overrideList=list(data.get("overrideList", [])),
|
|
50
|
+
configJsonString=data.get("configJsonString"),
|
|
51
|
+
forceAsimov=data.get("forceAsimov", data.get("useAsimov")),
|
|
52
|
+
dataType=data.get("dataType"),
|
|
53
|
+
randomSeed=data.get("randomSeed", data.get("seed")),
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
@classmethod
|
|
57
|
+
def fromJsonFile(cls, path: str | Path) -> "GundamContext":
|
|
58
|
+
with Path(path).open("r", encoding="utf-8") as file:
|
|
59
|
+
return cls.fromDict(json.load(file))
|
|
60
|
+
|
|
61
|
+
def toDict(self, includeConfigJsonString: bool = True) -> dict[str, Any]:
|
|
62
|
+
data = {
|
|
63
|
+
"nCpuThreads": self.nCpuThreads,
|
|
64
|
+
"pythonPath": str(self.pythonPath),
|
|
65
|
+
"workDir": str(self.workDir),
|
|
66
|
+
"dataType": self.dataType,
|
|
67
|
+
}
|
|
68
|
+
if self.randomSeed is not None:
|
|
69
|
+
data["randomSeed"] = self.randomSeed
|
|
70
|
+
if self.configJsonString is not None:
|
|
71
|
+
if includeConfigJsonString:
|
|
72
|
+
data["configJsonString"] = self.configJsonString
|
|
73
|
+
else:
|
|
74
|
+
data["configPath"] = str(self.configPath)
|
|
75
|
+
data["overrideList"] = [str(path) for path in self.overrideList]
|
|
76
|
+
return data
|
|
77
|
+
|
|
78
|
+
def toJsonFile(self, path: str | Path) -> None:
|
|
79
|
+
path = Path(path)
|
|
80
|
+
path.parent.mkdir(parents=True, exist_ok=True)
|
|
81
|
+
with path.open("w", encoding="utf-8") as file:
|
|
82
|
+
json.dump(self.toDict(), file, indent=2, sort_keys=True)
|
|
83
|
+
file.write("\n")
|
|
84
|
+
|
|
85
|
+
@staticmethod
|
|
86
|
+
def _canonicalDataType(dataType: str | None, forceAsimov: bool | None) -> str:
|
|
87
|
+
if dataType is None:
|
|
88
|
+
if forceAsimov is None or forceAsimov:
|
|
89
|
+
return "Asimov"
|
|
90
|
+
return "RealData"
|
|
91
|
+
|
|
92
|
+
normalized = dataType.replace("_", "").replace("-", "").lower()
|
|
93
|
+
aliases = {
|
|
94
|
+
"asimov": "Asimov",
|
|
95
|
+
"toy": "Toy",
|
|
96
|
+
"realdata": "RealData",
|
|
97
|
+
"real": "RealData",
|
|
98
|
+
"data": "RealData",
|
|
99
|
+
}
|
|
100
|
+
try:
|
|
101
|
+
return aliases[normalized]
|
|
102
|
+
except KeyError as error:
|
|
103
|
+
raise ValueError(
|
|
104
|
+
"dataType must be one of: Asimov, Toy, RealData"
|
|
105
|
+
) from error
|
|
106
|
+
|
|
107
|
+
@property
|
|
108
|
+
def absoluteConfigPath(self) -> Path:
|
|
109
|
+
if self.configPath is None:
|
|
110
|
+
raise ValueError("No configPath is defined for this GundamContext")
|
|
111
|
+
if self.configPath.is_absolute():
|
|
112
|
+
return self.configPath
|
|
113
|
+
return self.workDir / self.configPath
|
|
114
|
+
|
|
115
|
+
@property
|
|
116
|
+
def absoluteOverridePaths(self) -> list[Path]:
|
|
117
|
+
overridePaths = []
|
|
118
|
+
for overridePath in self.overrideList:
|
|
119
|
+
if overridePath.is_absolute():
|
|
120
|
+
overridePaths.append(overridePath)
|
|
121
|
+
else:
|
|
122
|
+
overridePaths.append(self.workDir / overridePath)
|
|
123
|
+
return overridePaths
|
|
124
|
+
|
|
125
|
+
@property
|
|
126
|
+
def defaultInitializeLogPath(self) -> Path:
|
|
127
|
+
return self.workDir / "gundam_initialize.log"
|
|
128
|
+
|
|
129
|
+
@property
|
|
130
|
+
def defaultEvaluateLogPath(self) -> Path:
|
|
131
|
+
return self.workDir / "gundam_evaluate.log"
|
|
132
|
+
|
|
133
|
+
def validatePaths(self) -> None:
|
|
134
|
+
"""Fail early on missing user-provided paths."""
|
|
135
|
+
if not self.pythonPath.exists():
|
|
136
|
+
raise FileNotFoundError(f"GUNDAM pythonPath does not exist: {self.pythonPath}")
|
|
137
|
+
if not self.workDir.exists():
|
|
138
|
+
raise FileNotFoundError(f"GUNDAM workDir does not exist: {self.workDir}")
|
|
139
|
+
if self.configJsonString is not None:
|
|
140
|
+
return
|
|
141
|
+
if not self.absoluteConfigPath.exists():
|
|
142
|
+
raise FileNotFoundError(f"GUNDAM config file does not exist: {self.absoluteConfigPath}")
|
|
143
|
+
for overridePath in self.absoluteOverridePaths:
|
|
144
|
+
if not overridePath.exists():
|
|
145
|
+
raise FileNotFoundError(f"GUNDAM override file does not exist: {overridePath}")
|
|
@@ -0,0 +1,347 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import importlib
|
|
4
|
+
import os
|
|
5
|
+
import sys
|
|
6
|
+
import tempfile
|
|
7
|
+
from contextlib import contextmanager
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Any, Iterator
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
|
|
14
|
+
from .config import GundamContext
|
|
15
|
+
from .logging import maybeRedirectNativeOutput, temporaryRedirectNativeOutput
|
|
16
|
+
from .parameters import (
|
|
17
|
+
GundamParameter,
|
|
18
|
+
collectActiveParameters,
|
|
19
|
+
normalizedToPhysical,
|
|
20
|
+
parameterPriors,
|
|
21
|
+
parameterSteps,
|
|
22
|
+
parameterThrowValues,
|
|
23
|
+
physicalToNormalized,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass(frozen=True, slots=True)
|
|
28
|
+
class PostfitThrowSamples:
|
|
29
|
+
"""GUNDAM post-fit throws with propagated likelihood evaluations."""
|
|
30
|
+
|
|
31
|
+
physicalValues: np.ndarray
|
|
32
|
+
normalizedValues: np.ndarray
|
|
33
|
+
llh: np.ndarray
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@contextmanager
def preservedWorkingDirectory() -> Iterator[None]:
    """Restore the current working directory when the block exits.

    The directory is captured on entry and restored on exit, whether the body
    finished normally or raised.
    """
    startDirectory = Path.cwd()
    try:
        yield
    finally:
        # Undo any chdir performed inside the block.
        os.chdir(startDirectory)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@contextmanager
|
|
46
|
+
def temporaryWorkingDirectory(path: str | os.PathLike[str]) -> Iterator[None]:
|
|
47
|
+
originalWorkingDirectory = Path.cwd()
|
|
48
|
+
os.chdir(Path(path).expanduser().resolve())
|
|
49
|
+
try:
|
|
50
|
+
yield
|
|
51
|
+
finally:
|
|
52
|
+
os.chdir(originalWorkingDirectory)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class GundamInterface:
    """Thin Python wrapper around the GUNDAM fitting interface.

    The guards below enforce the order of use: ``configure()`` must run before
    anything that needs the engine, and ``initialize()`` or
    ``refreshParameters()`` must have loaded at least one active parameter
    before values can be read, set, evaluated, minimised or thrown.

    Every method that calls into GUNDAM restores the caller's working
    directory afterwards.
    """

    def __init__(self, context: GundamContext):
        """Store ``context``; nothing is imported or configured yet."""
        self.context = context
        # Set by importGundam() / configure().
        self.gundam: Any | None = None
        self.configBuilder: Any | None = None
        self.configJsonString: str | None = None
        self.fitterEngineConfig: Any | None = None
        self.engine: Any | None = None
        # Set by refreshParameters().
        self.parameters: list[GundamParameter] = []

    @property
    def isConfigured(self) -> bool:
        """True once ``configure()`` has stored an engine."""
        return self.engine is not None

    @property
    def isInitialized(self) -> bool:
        """True when at least one active parameter is loaded."""
        return bool(self.parameters)

    @property
    def priors(self) -> np.ndarray:
        """Result of ``parameterPriors`` for the loaded parameters."""
        return parameterPriors(self.parameters)

    @property
    def stepSizes(self) -> np.ndarray:
        """Result of ``parameterSteps`` for the loaded parameters."""
        return parameterSteps(self.parameters)

    @property
    def throwValues(self) -> np.ndarray | None:
        """Result of ``parameterThrowValues`` for the loaded parameters."""
        return parameterThrowValues(self.parameters)

    @property
    def parameterNames(self) -> list[str]:
        """Names of the loaded parameters, in parameter order."""
        return [parameter.name for parameter in self.parameters]

    def setupPythonPath(self) -> None:
        """Prepend the context's ``pythonPath`` to ``sys.path`` and ``PYTHONPATH``.

        Each of the two is only modified when the path is not already listed.
        """
        pythonPath = str(self.context.pythonPath)
        if pythonPath not in sys.path:
            sys.path.insert(0, pythonPath)

        existingPythonPath = os.environ.get("PYTHONPATH", "")
        pythonPathParts = [part for part in existingPythonPath.split(os.pathsep) if part]
        if pythonPath not in pythonPathParts:
            os.environ["PYTHONPATH"] = os.pathsep.join([pythonPath, *pythonPathParts])

    def importGundam(self):
        """Import the ``GUNDAM`` module, store it on ``self.gundam`` and return it."""
        self.setupPythonPath()
        self.gundam = importlib.import_module("GUNDAM")
        return self.gundam

    def configure(self, validatePaths: bool = True) -> None:
        """Build and configure a fresh ``FitterEngine`` from the context.

        State on ``self`` is only replaced after every step succeeded. Any
        parameters loaded from a previous engine are dropped, so
        ``initialize()`` has to be called again after reconfiguring.

        Args:
            validatePaths: Run ``GundamContext.validatePaths()`` first.
        """
        with preservedWorkingDirectory():
            if validatePaths:
                self.context.validatePaths()

            gundam = self.importGundam()
            gundam.setLightOutputMode(False)
            gundam.setNumberOfThreads(self.context.nCpuThreads)
            workingDirectory = Path(self.context.workDir).expanduser().resolve()
            gundam.setRuntimeWorkingDirectory(str(workingDirectory))

            with temporaryWorkingDirectory(workingDirectory):
                configBuilder = self._buildConfigBuilder(gundam)
                configJsonString = configBuilder.toString()

            configReader = gundam.ConfigUtils.ConfigReader(configBuilder.getConfig())
            configReader.defineField(
                gundam.ConfigUtils.ConfigReader.FieldDefinition("fitterEngineConfig")
            )
            fitterEngineConfig = configReader.fetchValueConfigReader("fitterEngineConfig")

            engine = gundam.FitterEngine()
            engine.setConfig(fitterEngineConfig)
            self._setEngineRandomSeed(engine, self.context.randomSeed)
            with temporaryWorkingDirectory(workingDirectory):
                engine.configure()

            self.configBuilder = configBuilder
            self.configJsonString = configJsonString
            self.fitterEngineConfig = fitterEngineConfig
            self.engine = engine
            # Parameters collected from a previous engine are stale now; without
            # this reset isInitialized/_requireParameters would wrongly pass.
            self.parameters = []

    def _buildConfigBuilder(self, gundam):
        """Return a ``ConfigBuilder`` for the context's configuration.

        An inline JSON configuration takes precedence; in that case the
        override files are not applied. Otherwise the config file is loaded and
        every override file is applied in list order.
        """
        if self.context.configJsonString is not None:
            return self._buildConfigBuilderFromJsonString(gundam, self.context.configJsonString)

        configPath = Path(self.context.absoluteConfigPath).expanduser().resolve()
        overridePaths = [
            Path(overridePath).expanduser().resolve()
            for overridePath in self.context.absoluteOverridePaths
        ]
        configBuilder = gundam.ConfigUtils.ConfigBuilder(str(configPath))
        for overridePath in overridePaths:
            configBuilder.override(str(overridePath))
        return configBuilder

    @staticmethod
    def _buildConfigBuilderFromJsonString(gundam, configJsonString: str):
        """Return a ``ConfigBuilder`` built from an in-memory JSON string."""
        # The Python binding exposes ConfigBuilder(str), but that overload expects a file path.
        # Keep the public API string-based and isolate the temporary bridge here.
        with tempfile.NamedTemporaryFile(
            mode="w",
            suffix=".json",
            encoding="utf-8",
            delete=True,
        ) as configFile:
            configFile.write(configJsonString)
            # Make the content visible on disk before the builder opens the file.
            configFile.flush()
            return gundam.ConfigUtils.ConfigBuilder(str(configFile.name))

    def initialize(
        self,
        logPath: str | os.PathLike[str] | None = None,
    ) -> None:
        """Set the likelihood data type, initialise the engine and load parameters.

        Args:
            logPath: File receiving native output. When omitted,
                ``temporaryRedirectNativeOutput("gundam_initialize")`` is used.

        Raises:
            RuntimeError: If ``configure()`` has not been called.
        """
        with preservedWorkingDirectory():
            self._requireConfigured()
            workingDirectory = Path(self.context.workDir).expanduser().resolve()

            if logPath is None:
                redirectContext = temporaryRedirectNativeOutput("gundam_initialize")
            else:
                # Resolve before changing directory so relative paths refer to
                # the caller's working directory.
                logPath = Path(logPath).expanduser().resolve()
                redirectContext = maybeRedirectNativeOutput(logPath)

            with temporaryWorkingDirectory(workingDirectory):
                with redirectContext:
                    self._setLikelihoodDataType()
                    self.engine.initialize()

            self.refreshParameters()

    def refreshParameters(self) -> list[GundamParameter]:
        """Re-collect the active parameters from the engine and return them.

        Throw values are included only when the context's data type is
        ``"Toy"``.

        Raises:
            RuntimeError: If ``configure()`` has not been called.
        """
        self._requireConfigured()
        parametersManager = (
            self.engine.getLikelihoodInterface()
            .getModelPropagator()
            .getParametersManager()
        )
        self.parameters = collectActiveParameters(
            parametersManager,
            includeThrowValues=self.context.dataType == "Toy",
        )
        return self.parameters

    def getParameterValues(self) -> np.ndarray:
        """Return the current parameter values as a ``float64`` array."""
        self._requireParameters()
        return np.array([parameter.value for parameter in self.parameters], dtype=np.float64)

    def setParameterValues(self, values: np.ndarray) -> None:
        """Assign ``values`` to the loaded parameters, in parameter order.

        Raises:
            ValueError: If the shape of ``values`` differs from ``priors``.
        """
        self._requireParameters()
        values = np.asarray(values, dtype=np.float64)
        expectedShape = self.priors.shape
        if values.shape != expectedShape:
            raise ValueError(f"Expected parameter shape {expectedShape}, got {values.shape}")
        for parameter, value in zip(self.parameters, values):
            parameter.setValue(float(value))

    def resetToPrior(self) -> None:
        """Call ``resetToPrior()`` on every loaded parameter."""
        self._requireParameters()
        for parameter in self.parameters:
            parameter.resetToPrior()

    def normalizedToPhysical(self, normalizedValues: np.ndarray) -> np.ndarray:
        """Convert normalised values using the current priors and step sizes."""
        self._requireParameters()
        return normalizedToPhysical(normalizedValues, self.priors, self.stepSizes)

    def physicalToNormalized(self, physicalValues: np.ndarray) -> np.ndarray:
        """Convert physical values using the current priors and step sizes."""
        self._requireParameters()
        return physicalToNormalized(physicalValues, self.priors, self.stepSizes)

    def evaluateLlh(
        self,
        physicalValues: np.ndarray | None = None,
        normalizedValues: np.ndarray | None = None,
        logPath: str | os.PathLike[str] | None = None,
    ) -> float:
        """Propagate the parameters and return the resulting likelihood.

        With neither value array given, the likelihood is evaluated at the
        parameters' current values.

        Args:
            physicalValues: Values to set before evaluating.
            normalizedValues: Normalised values, converted to physical ones.
            logPath: File receiving native output; no redirection when omitted.

        Raises:
            ValueError: If both value arrays are given, or the shape is wrong.
            RuntimeError: If no parameters are loaded.
        """
        with preservedWorkingDirectory():
            self._requireParameters()
            if physicalValues is not None and normalizedValues is not None:
                raise ValueError("Provide either physicalValues or normalizedValues, not both")
            if normalizedValues is not None:
                physicalValues = self.normalizedToPhysical(normalizedValues)
            if physicalValues is not None:
                self.setParameterValues(physicalValues)

            if logPath is not None:
                logPath = Path(logPath).expanduser().resolve()
            workingDirectory = Path(self.context.workDir).expanduser().resolve()

            with temporaryWorkingDirectory(workingDirectory):
                with maybeRedirectNativeOutput(logPath):
                    self.engine.getLikelihoodInterface().propagateAndEvalLikelihood()
            return float(self.engine.getLikelihoodInterface().getLastLikelihood())

    def minimize(
        self,
        logPath: str | os.PathLike[str] | None = None,
    ) -> float:
        """Run the minimiser, refresh the parameters and return the last likelihood.

        Args:
            logPath: File receiving native output; no redirection when omitted.

        Raises:
            RuntimeError: If no parameters are loaded.
        """
        with preservedWorkingDirectory():
            self._requireParameters()
            if logPath is not None:
                logPath = Path(logPath).expanduser().resolve()
            workingDirectory = Path(self.context.workDir).expanduser().resolve()

            with temporaryWorkingDirectory(workingDirectory):
                with maybeRedirectNativeOutput(logPath):
                    self.engine.getMinimizer().minimize()

            self.refreshParameters()
            return float(self.engine.getLikelihoodInterface().getLastLikelihood())

    def evaluatePostfitThrows(
        self,
        nThrows: int,
        logPath: str | os.PathLike[str] | None = None,
        showProgress: bool = True,
    ) -> PostfitThrowSamples:
        """Throw post-fit parameters, propagate them, and evaluate their LLH.

        The GUNDAM binding only exposes ``throwPostfitParameters()`` as a state
        update on the minimizer. This method wraps that operation into a simple
        batch interface. ``logPath`` is accepted for backward compatibility but
        is intentionally ignored: native output is not redirected in this loop.

        Args:
            nThrows: Number of throws to draw; must be >= 1.
            logPath: Ignored.
            showProgress: Wrap the loop in a ``tqdm`` progress bar. ``tqdm`` is
                only imported when this is true.

        Raises:
            ValueError: If ``nThrows`` is below 1.
            RuntimeError: If no parameters are loaded.
        """
        with preservedWorkingDirectory():
            self._requireParameters()
            if nThrows < 1:
                raise ValueError("nThrows must be >= 1")
            workingDirectory = Path(self.context.workDir).expanduser().resolve()

            physicalValues = np.empty((nThrows, self.priors.shape[0]), dtype=np.float64)
            normalizedValues = np.empty_like(physicalValues)
            llh = np.empty(nThrows, dtype=np.float64)

            with temporaryWorkingDirectory(workingDirectory):
                minimizer = self.engine.getMinimizer()
                likelihoodInterface = self.engine.getLikelihoodInterface()
                throwIterator = range(nThrows)
                if showProgress:
                    # Imported here so that showProgress=False works without tqdm.
                    from tqdm.auto import tqdm

                    throwIterator = tqdm(
                        throwIterator,
                        desc="GUNDAM post-fit throws",
                        unit="throw",
                    )
                for throwIndex in throwIterator:
                    minimizer.throwPostfitParameters()
                    physicalValues[throwIndex] = self.getParameterValues()
                    normalizedValues[throwIndex] = self.physicalToNormalized(
                        physicalValues[throwIndex]
                    )
                    likelihoodInterface.propagateAndEvalLikelihood()
                    llh[throwIndex] = float(likelihoodInterface.getLastLikelihood())

            self.refreshParameters()
            return PostfitThrowSamples(
                physicalValues=physicalValues,
                normalizedValues=normalizedValues,
                llh=llh,
            )

    def setSeed(self, seed: int | None = None) -> None:
        """Set the engine's random seed, defaulting to the context's seed.

        Nothing happens when both ``seed`` and the context's seed are ``None``.

        Raises:
            RuntimeError: If ``configure()`` has not been called.
            ValueError: If the seed is negative.
        """
        self._requireConfigured()
        seed = self.context.randomSeed if seed is None else seed
        self._setEngineRandomSeed(self.engine, seed)

    @staticmethod
    def _setEngineRandomSeed(engine, seed: int | None) -> None:
        """Apply ``seed`` through ``setRandomSeed`` on the engine's class.

        A ``None`` seed is a no-op.

        Raises:
            ValueError: If ``seed`` is negative.
        """
        if seed is None:
            return
        seed = int(seed)
        if seed < 0:
            raise ValueError("seed must be >= 0")
        # setRandomSeed is called on the class, not on the instance.
        type(engine).setRandomSeed(seed)

    def _requireConfigured(self) -> None:
        """Raise ``RuntimeError`` unless ``configure()`` has stored an engine."""
        if self.engine is None:
            raise RuntimeError("GundamInterface.configure() must be called first")

    def _setLikelihoodDataType(self) -> None:
        """Set the likelihood interface's data type from the context's ``dataType``."""
        self._requireConfigured()
        gundam = self.importGundam()
        likelihoodInterface = self.engine.getLikelihoodInterface()
        # The canonical dataType string doubles as the enum member name.
        dataType = getattr(gundam.LikelihoodInterface.DataType, self.context.dataType)
        likelihoodInterface.setDataType(dataType)

    def _requireParameters(self) -> None:
        """Raise ``RuntimeError`` unless the engine exists and parameters are loaded."""
        self._requireConfigured()
        if not self.parameters:
            raise RuntimeError(
                "No active parameters are loaded. Call initialize() or refreshParameters() first."
            )
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import ctypes
|
|
4
|
+
import os
|
|
5
|
+
import sys
|
|
6
|
+
import tempfile
|
|
7
|
+
import threading
|
|
8
|
+
import time
|
|
9
|
+
import uuid
|
|
10
|
+
from contextlib import contextmanager, nullcontext
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Iterator
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def isNotebookRuntime() -> bool:
    """Report whether the process appears to run inside a Jupyter kernel.

    The checks are, in order: ``ipykernel`` already imported, the active
    IPython shell being a ``ZMQInteractiveShell``, and the shell's config
    containing ``IPKernelApp``. Without IPython installed, or without an
    active shell, the answer is false.
    """
    if "ipykernel" in sys.modules:
        return True

    try:
        from IPython import get_ipython
    except ModuleNotFoundError:
        return False

    activeShell = get_ipython()
    if activeShell is None:
        return False

    isZmqShell = activeShell.__class__.__name__ == "ZMQInteractiveShell"
    return isZmqShell or "IPKernelApp" in getattr(activeShell, "config", {})
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@contextmanager
|
|
35
|
+
def redirectNativeOutput(
|
|
36
|
+
logPath: str | os.PathLike[str],
|
|
37
|
+
*,
|
|
38
|
+
stream: bool = False,
|
|
39
|
+
) -> Iterator[None]:
|
|
40
|
+
"""Redirect C/C++ stdout and stderr to a file.
|
|
41
|
+
|
|
42
|
+
This is useful in Jupyter where native C++ loggers can interact poorly with
|
|
43
|
+
ipykernel stdout capture. Python stdout/stderr are restored after the block.
|
|
44
|
+
When ``stream`` is true, new log content is also printed as it is written.
|
|
45
|
+
"""
|
|
46
|
+
path = Path(logPath).expanduser()
|
|
47
|
+
path.parent.mkdir(parents=True, exist_ok=True)
|
|
48
|
+
|
|
49
|
+
libc = ctypes.CDLL(None)
|
|
50
|
+
sys.stdout.flush()
|
|
51
|
+
sys.stderr.flush()
|
|
52
|
+
libc.fflush(None)
|
|
53
|
+
|
|
54
|
+
stdoutFd = os.dup(1)
|
|
55
|
+
stderrFd = os.dup(2)
|
|
56
|
+
streamStopEvent: threading.Event | None = None
|
|
57
|
+
streamThread: threading.Thread | None = None
|
|
58
|
+
|
|
59
|
+
try:
|
|
60
|
+
with path.open("ab", buffering=0) as nativeLog:
|
|
61
|
+
streamOffset = nativeLog.tell()
|
|
62
|
+
if stream:
|
|
63
|
+
streamStopEvent = threading.Event()
|
|
64
|
+
streamThread = threading.Thread(
|
|
65
|
+
target=_streamLogFile,
|
|
66
|
+
args=(path, streamStopEvent, streamOffset, stdoutFd),
|
|
67
|
+
daemon=True,
|
|
68
|
+
)
|
|
69
|
+
streamThread.start()
|
|
70
|
+
os.dup2(nativeLog.fileno(), 1)
|
|
71
|
+
os.dup2(nativeLog.fileno(), 2)
|
|
72
|
+
yield
|
|
73
|
+
finally:
|
|
74
|
+
sys.stdout.flush()
|
|
75
|
+
sys.stderr.flush()
|
|
76
|
+
libc.fflush(None)
|
|
77
|
+
if streamStopEvent is not None:
|
|
78
|
+
streamStopEvent.set()
|
|
79
|
+
if streamThread is not None:
|
|
80
|
+
streamThread.join(timeout=2.0)
|
|
81
|
+
os.dup2(stdoutFd, 1)
|
|
82
|
+
os.dup2(stderrFd, 2)
|
|
83
|
+
os.close(stdoutFd)
|
|
84
|
+
os.close(stderrFd)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def maybeRedirectNativeOutput(
|
|
88
|
+
logPath: str | os.PathLike[str] | None,
|
|
89
|
+
*,
|
|
90
|
+
stream: bool = True,
|
|
91
|
+
):
|
|
92
|
+
if logPath is None:
|
|
93
|
+
return nullcontext()
|
|
94
|
+
return redirectNativeOutput(logPath, stream=stream)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@contextmanager
def temporaryRedirectNativeOutput(prefix: str):
    """Capture native output in a throw-away log file when running in a notebook.

    Outside a notebook runtime the block runs without any redirection. Inside
    one, native output goes to a uniquely named file in the system temp
    directory; after the block the captured text is printed once and the file
    is removed.

    Args:
        prefix: Leading part of the temporary log file name.
    """
    if isNotebookRuntime():
        uniqueSuffix = uuid.uuid4().hex[:8]
        fileName = f"{prefix}_{os.getpid()}_{uniqueSuffix}.log"
        captureFile = Path(tempfile.gettempdir()) / fileName
        try:
            with redirectNativeOutput(captureFile):
                yield
        finally:
            try:
                if captureFile.exists():
                    captured = captureFile.read_text(encoding="utf-8", errors="replace")
                    if captured:
                        # Terminate the output with exactly one newline.
                        terminator = "" if captured.endswith("\n") else "\n"
                        print(captured, end=terminator)
            finally:
                # Remove the file even if reading or printing failed.
                captureFile.unlink(missing_ok=True)
    else:
        yield
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _streamLogFile(
    path: Path,
    stopEvent: threading.Event,
    startOffset: int,
    outputFd: int,
    pollInterval: float = 0.1,
) -> None:
    """Print bytes appended to ``path`` until ``stopEvent`` is set.

    Reading starts at ``startOffset``. After the stop event is observed, the
    file is drained one last time so content written just before the stop is
    not lost. Any ``OSError`` ends the streaming silently.

    Args:
        path: Log file to follow.
        stopEvent: Set by the owner to end the loop.
        startOffset: Byte offset at which to start reading.
        outputFd: Descriptor handed to ``_streamAvailableLogBytes``.
        pollInterval: Longest pause between two polls, in seconds.
    """
    try:
        with path.open("rb") as logFile:
            logFile.seek(startOffset)
            while True:
                _streamAvailableLogBytes(logFile, outputFd)
                # Wait on the event rather than sleeping, so a stop request
                # ends the pause at once instead of blocking the joining
                # caller for the rest of the poll interval.
                if stopEvent.wait(pollInterval):
                    break
            _streamAvailableLogBytes(logFile, outputFd)
    except OSError:
        return
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _streamAvailableLogBytes(logFile, outputFd: int) -> None:
|
|
140
|
+
while True:
|
|
141
|
+
chunk = logFile.read(8192)
|
|
142
|
+
if not chunk:
|
|
143
|
+
return
|
|
144
|
+
if isNotebookRuntime():
|
|
145
|
+
sys.stdout.write(chunk.decode("utf-8", errors="replace"))
|
|
146
|
+
sys.stdout.flush()
|
|
147
|
+
else:
|
|
148
|
+
os.write(outputFd, chunk)
|