ionworks-api 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ionworks/__init__.py +3 -0
- ionworks/cell_instance.py +201 -0
- ionworks/cell_measurement.py +339 -0
- ionworks/cell_specification.py +108 -0
- ionworks/client.py +126 -0
- ionworks/errors.py +35 -0
- ionworks/job.py +111 -0
- ionworks/models.py +185 -0
- ionworks/pipeline.py +256 -0
- ionworks/simulation.py +371 -0
- ionworks/validators.py +171 -0
- ionworks_api-0.1.0.dist-info/METADATA +318 -0
- ionworks_api-0.1.0.dist-info/RECORD +14 -0
- ionworks_api-0.1.0.dist-info/WHEEL +4 -0
ionworks/pipeline.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import re
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from pydantic import (
|
|
6
|
+
BaseModel,
|
|
7
|
+
Field,
|
|
8
|
+
ValidationError,
|
|
9
|
+
field_validator,
|
|
10
|
+
model_validator,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
from .errors import IonworksError
|
|
14
|
+
from .validators import run_validators_outbound
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _prepare_payload(data: Any) -> Any:
    """Run *data* through the outbound validators before sending it to the API.

    Parameters
    ----------
    data : Any
        Payload to transform (in this module, a ``model_dump()`` dict).

    Returns
    -------
    Any
        Whatever ``run_validators_outbound`` produces for ``data``.
    """
    prepared = run_validators_outbound(data)
    return prepared
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class DataFitConfig(BaseModel):
    """Data-fit configuration.

    Requires ``objectives`` and ``parameters``. The ``cost``, ``optimizer``
    and ``existing_parameters`` mappings are optional.
    """

    objectives: dict[str, Any] = Field(...)
    parameters: dict[str, Any] = Field(...)
    cost: dict[str, Any] | None = Field(default=None)
    optimizer: dict[str, Any] | None = Field(default=None)
    existing_parameters: dict[str, Any] | None = Field(default=None)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class EntryConfig(BaseModel):
    """Entry configuration carrying an explicit mapping of ``values``."""

    values: dict[str, Any] = Field(...)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class BuiltInEntryConfig(BaseModel):
    """Entry configuration that refers to a built-in entry by ``name``."""

    name: str = Field(...)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class CalculationConfig(BaseModel):
    """Calculation configuration.

    Requires a ``calculation`` name. The ``electrode``, ``method`` and
    ``existing_parameters`` fields are optional.
    """

    calculation: str = Field(...)
    electrode: str | None = Field(default=None)
    method: str | None = Field(default=None)
    existing_parameters: dict[str, Any] | None = Field(default=None)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class ValidationConfig(BaseModel):
    """Validation configuration.

    Requires ``objectives`` and a list of ``summary_stats``, with optional
    ``existing_parameters``.
    """

    objectives: dict[str, Any] = Field(...)
    summary_stats: list[Any] = Field(...)
    existing_parameters: dict[str, Any] | None = Field(default=None)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class PipelineConfig(BaseModel):
    """Validated configuration for a pipeline submission.

    After validation, ``project_id`` is set. It falls back to the
    ``PROJECT_ID`` environment variable when not supplied. ``options`` is
    always a dict, never ``None``.
    """

    project_id: str | None = Field(
        default=None,
        description="The project id to submit the pipeline to. "
        "Can be found in the project settings page. "
        "If not provided, will use PROJECT_ID environment variable.",
    )
    elements: dict[str, Any] = Field(
        description="Dictionary of elements defining the pipeline. The key is the name "
        "of the element and the value is the configuration of the element. "
    )
    name: str | None = Field(default=None, description="The name of the pipeline.")
    description: str | None = Field(
        default=None, description="The description of the pipeline."
    )
    options: dict[str, Any] | None = Field(
        default=None,
        description="Dictionary of options for the pipeline. "
        "Options are used to configure the pipeline behavior. "
        "Available options are: "
        "live_progress_updates: bool ",
    )

    @field_validator("elements", mode="before")
    @classmethod
    def validate_elements_format(cls, v: Any) -> dict[str, Any]:
        """Reject non-dict ``elements``, with a migration hint for the legacy list form.

        Raises
        ------
        ValueError
            If ``v`` is a list (the old format) or any other non-dict value.
        """
        if isinstance(v, list):
            raise ValueError(
                "Pipeline elements must be provided as a dictionary, not a list. "
                "The format has changed from:\n"
                ' "elements": [{"name": {...}}, ...]\n'
                "to:\n"
                ' "elements": {"name": {...}, ...}\n'
                "Please update your pipeline configuration."
            )
        if not isinstance(v, dict):
            raise ValueError(
                f"Pipeline elements must be a dictionary, got {type(v).__name__}"
            )
        return v

    @model_validator(mode="after")
    def set_defaults(self) -> "PipelineConfig":
        """Fill in ``project_id`` from the environment and default ``options``.

        Raises
        ------
        ValueError
            If no ``project_id`` was given and ``PROJECT_ID`` is unset or empty.
        """
        if self.project_id is None:
            env_project_id = os.getenv("PROJECT_ID")
            # An empty PROJECT_ID (e.g. ``export PROJECT_ID=``) can never be a
            # valid project id, so treat it exactly like an unset variable.
            if not env_project_id:
                raise ValueError(
                    "project_id is required. Either provide it in the config "
                    "or set the PROJECT_ID environment variable."
                )
            self.project_id = env_project_id
        # Ensure options is never None to avoid 422 errors
        if self.options is None:
            self.options = {}
        return self
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class DataFitResponse(BaseModel):
    """Result of a data fit: the fitted ``parameter_values`` mapping."""

    parameter_values: dict[str, Any] = Field(...)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class CalculationResponse(BaseModel):
    """Result of a calculation: the computed ``parameter_values`` mapping."""

    parameter_values: dict[str, Any] = Field(...)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class ValidationResponse(BaseModel):
    """Result of a validation run.

    Contains ``validation_results`` and ``summary_stats``, the latter mapping
    each statistic name to a list of values.
    """

    validation_results: dict[str, Any] = Field(...)
    summary_stats: dict[str, list[Any]] = Field(...)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class EntryResponse(BaseModel):
    """Result of an entry element: its ``parameter_values`` mapping."""

    parameter_values: dict[str, Any] = Field(...)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class PipelineSubmissionResponse(BaseModel):
    """Pipeline record returned by the ``/pipelines`` endpoints.

    ``id``, ``name`` and ``status`` are required; ``description`` and
    ``error`` may be absent.
    """

    id: str = Field(...)
    name: str = Field(...)
    description: str | None = Field(default=None)
    status: str = Field(...)
    error: str | None = Field(default=None)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
class PipelineResponse(BaseModel):
    """Pipeline result payload: an overall ``result`` plus per-element results."""

    result: dict[str, Any] = Field(...)
    element_results: dict[str, Any] = Field(...)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class PipelineClient:
    """Client for the ``/pipelines`` endpoints.

    Parameters
    ----------
    client : Any
        Low-level API client. This class calls its ``post(endpoint, payload)``
        and ``get(endpoint)`` methods.
    """

    def __init__(self, client: Any):
        self.client = client

    def create(self, config: dict[str, Any]) -> PipelineSubmissionResponse:
        """Validate a pipeline configuration and submit it.

        Parameters
        ----------
        config : dict[str, Any]
            Keyword arguments for ``PipelineConfig``: ``elements`` plus the
            optional ``project_id``, ``name``, ``description`` and ``options``.

        Returns
        -------
        PipelineSubmissionResponse
            The pipeline submission response.

        Raises
        ------
        ValueError
            In any of these cases:

            - The configuration fails validation.
            - The API reports that the project id is not a valid UUID.
            - The API reports that the project is not accessible with the
              current API key.
        IonworksError
            Re-raised unchanged for any other API error.
        """
        try:
            validated_config = PipelineConfig(**config)
            payload = _prepare_payload(validated_config.model_dump())
            response_data = self.client.post("/pipelines", payload)
            return PipelineSubmissionResponse(**response_data)
        except ValidationError as e:
            raise ValueError(f"Invalid pipeline configuration: {e}") from e
        except IonworksError as e:
            error_msg = str(e.message)
            # Check for invalid UUID format in project_id
            uuid_match = re.search(
                r'invalid input syntax for type uuid: "([^"]*)"', error_msg
            )
            if uuid_match:
                invalid_id = uuid_match.group(1)
                raise ValueError(
                    f"Invalid project_id format: '{invalid_id}' is not a valid UUID. "
                    "Please provide a valid project ID from your project settings page."
                ) from e
            # Check for row-level security policy violation (project not accessible)
            if "violates row-level security policy" in error_msg:
                project_id = validated_config.project_id
                raise ValueError(
                    f"Access denied: The project '{project_id}' is not accessible "
                    "with your API key. Please verify that your API key has access "
                    "to this project."
                ) from e
            # Re-raise original error for other cases
            raise

    def list(self, project_id: str | None = None) -> list[PipelineSubmissionResponse]:
        """List the pipelines of a project.

        Parameters
        ----------
        project_id : str, optional
            The project id to filter pipelines. If not provided, the
            ``PROJECT_ID`` environment variable is used; an empty value counts
            as unset.

        Returns
        -------
        list[PipelineSubmissionResponse]
            List of pipeline submission responses.

        Raises
        ------
        ValueError
            If no project id is available, the response is not a list, or an
            item does not match ``PipelineSubmissionResponse``.
        """
        from urllib.parse import urlencode

        if project_id is None:
            # ``or None`` folds an empty PROJECT_ID into the "missing" case.
            project_id = os.getenv("PROJECT_ID") or None
        if project_id is None:
            raise ValueError(
                "project_id is required. Either provide it as an argument "
                "or set the PROJECT_ID environment variable."
            )

        # Encode the id so characters like '&', '#' or spaces cannot corrupt
        # the query string; plain UUIDs pass through unchanged.
        endpoint = f"/pipelines?{urlencode({'project_id': project_id})}"
        response_data = self.client.get(endpoint)
        if not isinstance(response_data, list):
            raise ValueError(
                f"Unexpected response format from {endpoint}: expected a list, "
                f"got {type(response_data).__name__}"
            )
        try:
            return [PipelineSubmissionResponse(**item) for item in response_data]
        except ValidationError as e:
            raise ValueError(f"Invalid item format in list from {endpoint}: {e}") from e

    def get(self, job_id: str) -> PipelineSubmissionResponse:
        """Fetch the pipeline record for ``job_id``.

        Parameters
        ----------
        job_id : str
            The job id.

        Returns
        -------
        PipelineSubmissionResponse
            The pipeline submission response.
        """
        response_data = self.client.get(f"/pipelines/{job_id}")
        return PipelineSubmissionResponse(**response_data)

    def result(self, job_id: str) -> PipelineResponse:
        """Fetch the results of the pipeline ``job_id``.

        Parameters
        ----------
        job_id : str
            The job id.

        Returns
        -------
        PipelineResponse
            The pipeline results.
        """
        response_data = self.client.get(f"/pipelines/{job_id}/result")
        return PipelineResponse(**response_data)
|
ionworks/simulation.py
ADDED
|
@@ -0,0 +1,371 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from datetime import datetime, timedelta
|
|
4
|
+
import time
|
|
5
|
+
from typing import Any, List, cast
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel, Field, ValidationError
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class QuickModelConfig(BaseModel):
    """Minimal cell description (capacity and chemistry) for protocol-based simulations."""

    capacity: float = Field(1.0, description="Cell capacity in Ah")
    chemistry: str = Field("NMC/Graphite", description="Chemistry name")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ProtocolExperimentConfig(BaseModel):
    """Protocol experiment: a YAML protocol string together with its name."""

    protocol: str = Field(..., description="YAML protocol string (UCP format)")
    name: str = Field(..., description="Protocol name for template naming")
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ProtocolSimulationRequest(BaseModel):
    """Request body for one protocol-based simulation (``/simulations/protocol``)."""

    parameterized_model: Any = Field(
        ...,
        description="Model can be: quick_model dict, full model dict, or model ID string",
    )
    protocol_experiment: ProtocolExperimentConfig = Field(
        ..., description="Protocol experiment configuration"
    )
    experiment_parameters: dict[str, float] | None = Field(
        None, description="Experiment parameters for any inputs in the protocol."
    )
    design_parameters: dict[str, float] | None = Field(
        None, description="Design parameters for the simulation"
    )
    max_backward_jumps: int | None = Field(
        None, description="Maximum backward jumps allowed (for goto statements)"
    )
    study_id: str | None = Field(None, description="Optional study UUID")
    extra_variables: list[str] | None = Field(
        None,
        description="Optional list of extra variables to include in simulation output "
        "(e.g., ['Negative electrode potential [V]', 'Positive electrode "
        "potential [V]']). If provided, these override any extra variables "
        "defined in the experiment template.",
    )
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class DOERow(BaseModel):
    """One design-of-experiments row: a named parameter and how to sample it.

    The optional fields are grouped by the ``type`` they apply to.
    """

    type: str = Field(..., description="Type: 'range', 'discrete', or 'normal'")
    name: str = Field(..., description="Parameter name")

    # Used when type == 'range'
    min: float | None = Field(None, description="Minimum value")
    max: float | None = Field(None, description="Maximum value")
    count: int | None = Field(None, description="Number of samples (for grid/random)")

    # Used when type == 'discrete'
    values: list[float] | None = Field(None, description="Discrete values")

    # Used when type == 'normal'
    mean: float | None = Field(None, description="Mean value")
    std: float | None = Field(None, description="Standard deviation")
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class DesignParametersDOE(BaseModel):
    """Design-of-experiments setup: a sampling strategy over a set of rows."""

    sampling: str = Field(
        ..., description="Sampling strategy: 'grid', 'random', or 'latin_hypercube'"
    )
    rows: list[DOERow] = Field(..., description="DOE row configurations")
    count: int | None = Field(None, description="Total count for non-grid sampling")
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class ProtocolSimulationBatchRequest(BaseModel):
    """Request body for a DOE batch of protocol simulations (``/simulations/protocol/batch``)."""

    parameterized_model: Any = Field(
        ...,
        description="Model can be: quick_model dict, full model dict, or model ID string",
    )
    protocol_experiment: ProtocolExperimentConfig = Field(
        ..., description="Protocol experiment configuration"
    )
    design_parameters_doe: DesignParametersDOE = Field(
        ..., description="Design of experiments configuration"
    )
    experiment_parameters: dict[str, float] | None = Field(
        None, description="Experiment parameters for any inputs in the protocol."
    )
    max_backward_jumps: int | None = Field(
        None, description="Maximum backward jumps allowed (for goto statements)"
    )
    study_id: str | None = Field(None, description="Optional study UUID")
    extra_variables: list[str] | None = Field(
        None,
        description="Optional list of extra variables to include in simulation output "
        "(e.g., ['Negative electrode potential [V]', 'Positive electrode "
        "potential [V]']). If provided, these override any extra variables "
        "defined in the experiment template.",
    )
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class SimulationResponse(BaseModel):
    """Identifiers returned by the API when a simulation is created."""

    simulation_id: str = Field(..., description="Simulation UUID")
    job_id: str = Field(..., description="Job UUID")
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class SimulationClient:
    """Client for running simulations through the ``/simulations`` endpoints.

    Parameters
    ----------
    client : Any
        Low-level API client. This class calls its
        ``post(endpoint, json_payload=...)`` and ``get(endpoint)`` methods.
    """

    def __init__(self, client: Any):
        self.client = client

    def protocol(self, config: dict[str, Any]) -> SimulationResponse:
        """Create a single protocol-based simulation.

        Parameters
        ----------
        config : dict[str, Any]
            Configuration dictionary containing:

            - parameterized_model: quick_model dict, full model dict, or model
              ID string
            - protocol_experiment: ProtocolExperimentConfig dict with protocol
              and name fields
            - experiment_parameters: Optional dict[str, float] with values for
              any inputs referenced by the protocol
            - design_parameters: Optional dict[str, float]
            - max_backward_jumps: Optional int
            - study_id: Optional str
            - extra_variables: Optional list[str] - Extra variables to include
              in simulation output

        Returns
        -------
        SimulationResponse
            Response containing simulation_id and job_id.

        Raises
        ------
        ValueError
            If the configuration is invalid, or the API response does not
            match ``SimulationResponse``.
        """
        endpoint = "/simulations/protocol"
        try:
            validated_config = ProtocolSimulationRequest(**config)
        except ValidationError as e:
            raise ValueError(f"Invalid protocol simulation configuration: {e}") from e
        response_data = self.client.post(
            endpoint, json_payload=validated_config.model_dump(exclude_none=True)
        )
        try:
            return SimulationResponse(**response_data)
        except ValidationError as e:
            # A malformed server response is not a problem with the caller's config.
            raise ValueError(f"Invalid simulation response from {endpoint}: {e}") from e

    def protocol_batch(self, config: dict[str, Any]) -> list[SimulationResponse]:
        """Create multiple protocol-based simulations using DOE.

        Parameters
        ----------
        config : dict[str, Any]
            Configuration dictionary containing:

            - parameterized_model: quick_model dict, full model dict, or model
              ID string
            - protocol_experiment: ProtocolExperimentConfig dict with protocol
              and name fields
            - design_parameters_doe: DesignParametersDOE dict
            - experiment_parameters: Optional dict[str, float] with values for
              any inputs referenced by the protocol
            - max_backward_jumps: Optional int
            - study_id: Optional str
            - extra_variables: Optional list[str] - Extra variables to include
              in simulation output

        Returns
        -------
        list[SimulationResponse]
            List of responses, each containing simulation_id and job_id.

        Raises
        ------
        ValueError
            If the configuration is invalid, the response is not a list, or an
            item does not match ``SimulationResponse``.
        """
        endpoint = "/simulations/protocol/batch"
        try:
            validated_config = ProtocolSimulationBatchRequest(**config)
        except ValidationError as e:
            raise ValueError(
                f"Invalid batch protocol simulation configuration: {e}"
            ) from e
        response_data = self.client.post(
            endpoint, json_payload=validated_config.model_dump(exclude_none=True)
        )
        if not isinstance(response_data, list):
            raise ValueError(
                f"Unexpected response format from {endpoint}: expected a "
                f"list, got {type(response_data).__name__}"
            )
        try:
            return [SimulationResponse(**item) for item in response_data]
        except ValidationError as e:
            raise ValueError(f"Invalid item format in list from {endpoint}: {e}") from e

    def list(self) -> list[dict[str, Any]]:
        """List all simulations for the current user.

        Returns
        -------
        list[dict[str, Any]]
            List of simulation objects with joined model and experiment data.

        Raises
        ------
        ValueError
            If the API response is not a list.
        """
        endpoint = "/simulations"
        response_data = self.client.get(endpoint)
        if not isinstance(response_data, list):
            raise ValueError(
                f"Unexpected response format from {endpoint}: expected a list, "
                f"got {type(response_data).__name__}"
            )
        return response_data

    def get(self, simulation_id: str) -> dict[str, Any]:
        """Get a specific simulation by ID.

        Parameters
        ----------
        simulation_id : str
            The UUID of the simulation to retrieve.

        Returns
        -------
        dict[str, Any]
            Simulation object with full joined data including model, experiment,
            and simulation_data (null if not completed).
        """
        endpoint = f"/simulations/{simulation_id}"
        response_data = self.client.get(endpoint)
        return cast(dict[str, Any], response_data)

    def get_result(self, simulation_id: str) -> dict[str, Any]:
        """Get simulation data/result for a completed simulation.

        Parameters
        ----------
        simulation_id : str
            The UUID of the simulation.

        Returns
        -------
        dict[str, Any]
            Simulation data object containing time_series, steps, and metrics.

        Raises
        ------
        Exception
            If simulation data is not found (the simulation may not have
            completed yet). The underlying client is expected to raise for the
            API's 404 response.
        """
        endpoint = f"/simulations/{simulation_id}/result"
        response_data = self.client.get(endpoint)
        return cast(dict[str, Any], response_data)

    def wait_for_completion(
        self,
        simulation_id: str | List[str],  # List to avoid conflict with list() method
        timeout: int = 60,
        poll_interval: int = 2,
        verbose: bool = True,
    ) -> dict[str, Any] | List[dict[str, Any]]:
        """Wait for simulation(s) to complete by polling until done or timeout.

        A simulation counts as complete once ``get`` returns a record whose
        ``simulation_data`` is not ``None``. Errors raised while polling are
        tolerated so that transient failures do not abort the wait.

        Parameters
        ----------
        simulation_id : str | list[str]
            Single simulation ID or list of simulation IDs to wait for.
            Duplicate IDs are polled once.
        timeout : int, optional
            Maximum time to wait in seconds (default: 60).
        poll_interval : int, optional
            Time between polls in seconds (default: 2).
        verbose : bool, optional
            Whether to print status updates (default: True).

        Returns
        -------
        dict[str, Any] | list[dict[str, Any]]
            Completed simulation(s). A single dict is returned when a single ID
            is given, and a list in input order when a list is given. If the
            timeout is reached with a list, only the completed simulations are
            returned.

        Raises
        ------
        TimeoutError
            Raised when the timeout is reached and either no simulation has
            completed, or the single requested simulation has not completed.
            The last polling error, if any, is chained as ``__cause__``.
        """
        is_single = isinstance(simulation_id, str)
        if is_single:
            simulation_ids: List[str] = [simulation_id]  # type: ignore[list-item]
        else:
            simulation_ids = simulation_id  # type: ignore[assignment]

        # Poll each distinct ID once, in input order. Comparing completion
        # against the *unique* IDs means duplicates cannot keep the loop
        # running until the timeout.
        unique_ids = list(dict.fromkeys(simulation_ids))
        completed_simulations: dict[str, dict[str, Any]] = {}
        last_error: Exception | None = None

        # Monotonic clock: unaffected by wall-clock adjustments during the wait.
        start = time.monotonic()
        deadline = start + timeout

        if verbose:
            print(f"Polling for {len(unique_ids)} simulation(s) completion...")

        while time.monotonic() < deadline:
            for sim_id in unique_ids:
                if sim_id in completed_simulations:
                    continue
                try:
                    simulation = self.get(sim_id)
                    if simulation.get("simulation_data") is not None:
                        completed_simulations[sim_id] = simulation
                except Exception as exc:
                    # Best-effort polling: keep going, but remember the most
                    # recent failure so a timeout can report its cause.
                    last_error = exc

            elapsed = int(time.monotonic() - start)
            if verbose:
                print(
                    f" Status: {len(completed_simulations)}/{len(unique_ids)} completed "
                    f"(elapsed: {elapsed}s)"
                )

            if len(completed_simulations) == len(unique_ids):
                if verbose:
                    print("All simulations completed!")
                break

            # Never sleep past the deadline.
            remaining = deadline - time.monotonic()
            if remaining > 0:
                time.sleep(min(poll_interval, remaining))
        else:
            # Timeout reached
            if verbose:
                print(
                    f"Timeout: Only {len(completed_simulations)}/{len(unique_ids)} "
                    f"simulations completed within {timeout} seconds"
                )
            if not completed_simulations:
                raise TimeoutError(
                    f"No simulations completed within {timeout} seconds"
                ) from last_error

        # Return results in the same format as input
        if is_single:
            sole_id = simulation_ids[0]
            if sole_id not in completed_simulations:
                raise TimeoutError(
                    f"Simulation {sole_id} did not complete within "
                    f"{timeout} seconds"
                ) from last_error
            return completed_simulations[sole_id]

        # Preserve input order (including any duplicates the caller passed).
        return [
            completed_simulations[sim_id]
            for sim_id in simulation_ids
            if sim_id in completed_simulations
        ]
|