ionworks-api 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ionworks/pipeline.py ADDED
@@ -0,0 +1,256 @@
1
+ import os
2
+ import re
3
+ from typing import Any
4
+
5
+ from pydantic import (
6
+ BaseModel,
7
+ Field,
8
+ ValidationError,
9
+ field_validator,
10
+ model_validator,
11
+ )
12
+
13
+ from .errors import IonworksError
14
+ from .validators import run_validators_outbound
15
+
16
+
17
def _prepare_payload(data: Any) -> Any:
    """Pass ``data`` through the outbound validators before sending it.

    Thin wrapper over ``run_validators_outbound``: whatever that helper
    produces is returned unchanged and becomes the request body.
    """
    prepared = run_validators_outbound(data)
    return prepared
20
+
21
+
22
class DataFitConfig(BaseModel):
    """Schema for a data-fit configuration.

    ``objectives`` and ``parameters`` are required. The remaining sections
    are optional and default to ``None``. Dict contents are not inspected
    here (typed as ``dict[str, Any]``).
    """

    objectives: dict[str, Any]
    parameters: dict[str, Any]
    # Optional sections; absent means "not supplied".
    cost: dict[str, Any] | None = Field(default=None)
    optimizer: dict[str, Any] | None = Field(default=None)
    existing_parameters: dict[str, Any] | None = Field(default=None)
28
+
29
+
30
class EntryConfig(BaseModel):
    """Schema for an entry configuration given as an explicit ``values`` dict."""

    # Required; contents are not validated beyond being a dict.
    values: dict[str, Any] = Field()
32
+
33
+
34
class BuiltInEntryConfig(BaseModel):
    """Schema for an entry configuration identified only by ``name``."""

    # Required identifier; presumably names a server-side built-in entry (TODO confirm).
    name: str = Field()
36
+
37
+
38
class CalculationConfig(BaseModel):
    """Schema for a calculation configuration.

    Only ``calculation`` is required; the other fields default to ``None``.
    """

    calculation: str
    # Optional qualifiers for the calculation.
    electrode: str | None = Field(default=None)
    method: str | None = Field(default=None)
    existing_parameters: dict[str, Any] | None = Field(default=None)
43
+
44
+
45
class ValidationConfig(BaseModel):
    """Schema for a validation configuration.

    ``objectives`` and ``summary_stats`` are required;
    ``existing_parameters`` is optional.
    """

    objectives: dict[str, Any]
    summary_stats: list[Any]
    existing_parameters: dict[str, Any] | None = Field(default=None)
49
+
50
+
51
class PipelineConfig(BaseModel):
    """Top-level pipeline configuration, validated client-side before submission.

    ``project_id`` falls back to the ``PROJECT_ID`` environment variable and
    ``options`` is normalised to ``{}`` so neither is ever ``None`` after
    validation.
    """

    project_id: str | None = Field(
        default=None,
        description="The project id to submit the pipeline to. "
        "Can be found in the project settings page. "
        "If not provided, will use PROJECT_ID environment variable.",
    )
    elements: dict[str, Any] = Field(
        description="Dictionary of elements defining the pipeline. The key is the name "
        "of the element and the value is the configuration of the element. "
    )
    name: str | None = Field(default=None, description="The name of the pipeline.")
    description: str | None = Field(
        default=None, description="The description of the pipeline."
    )
    options: dict[str, Any] | None = Field(
        default=None,
        description="Dictionary of options for the pipeline. "
        "Options are used to configure the pipeline behavior. "
        "Available options are: "
        "live_progress_updates: bool ",
    )

    @field_validator("elements", mode="before")
    @classmethod
    def validate_elements_format(cls, v: Any) -> dict[str, Any]:
        """Validate that elements is a dict, not the old list format.

        Raises
        ------
        ValueError
            If ``v`` is a list (legacy format, with migration hint) or any
            other non-dict type. Pydantic wraps this in a ValidationError.
        """
        if isinstance(v, list):
            raise ValueError(
                "Pipeline elements must be provided as a dictionary, not a list. "
                "The format has changed from:\n"
                ' "elements": [{"name": {...}}, ...]\n'
                "to:\n"
                ' "elements": {"name": {...}, ...}\n'
                "Please update your pipeline configuration."
            )
        if not isinstance(v, dict):
            raise ValueError(
                f"Pipeline elements must be a dictionary, got {type(v).__name__}"
            )
        return v

    @model_validator(mode="after")
    def set_defaults(self) -> "PipelineConfig":
        """Fill ``project_id`` from ``PROJECT_ID`` and default ``options`` to ``{}``.

        Raises
        ------
        ValueError
            If no ``project_id`` was given and ``PROJECT_ID`` is unset or
            empty. Pydantic wraps this in a ValidationError.
        """
        if self.project_id is None:
            env_project_id = os.getenv("PROJECT_ID")
            # An exported-but-empty variable (PROJECT_ID=) is treated as unset;
            # otherwise "" would be submitted and fail server-side with an
            # opaque UUID syntax error.
            if not env_project_id:
                raise ValueError(
                    "project_id is required. Either provide it in the config "
                    "or set the PROJECT_ID environment variable."
                )
            self.project_id = env_project_id
        # Ensure options is never None to avoid 422 errors
        if self.options is None:
            self.options = {}
        return self
108
+
109
+
110
class DataFitResponse(BaseModel):
    """Data-fit output: the resulting ``parameter_values`` mapping."""

    parameter_values: dict[str, Any] = Field()
112
+
113
+
114
class CalculationResponse(BaseModel):
    """Calculation output: the resulting ``parameter_values`` mapping."""

    parameter_values: dict[str, Any] = Field()
116
+
117
+
118
class ValidationResponse(BaseModel):
    """Validation output.

    ``validation_results`` is a free-form dict; ``summary_stats`` maps each
    key to a list of values.
    """

    validation_results: dict[str, Any] = Field()
    summary_stats: dict[str, list[Any]] = Field()
121
+
122
+
123
class EntryResponse(BaseModel):
    """Entry output: the resulting ``parameter_values`` mapping."""

    parameter_values: dict[str, Any] = Field()
125
+
126
+
127
class PipelineSubmissionResponse(BaseModel):
    """Pipeline record as returned by the ``/pipelines`` endpoints.

    ``id``, ``name`` and ``status`` are required; ``description`` and
    ``error`` may be absent or null.
    """

    id: str
    name: str
    description: str | None = Field(default=None)
    status: str
    error: str | None = Field(default=None)
133
+
134
+
135
class PipelineResponse(BaseModel):
    """Result payload of ``GET /pipelines/{job_id}/result``.

    Holds the overall ``result`` plus ``element_results``; both are
    free-form dicts.
    """

    result: dict[str, Any] = Field()
    element_results: dict[str, Any] = Field()
138
+
139
+
140
class PipelineClient:
    """Client for submitting pipelines and retrieving their status and results.

    All HTTP traffic goes through the wrapped ``client`` object, via its
    ``post(endpoint, payload)`` and ``get(endpoint)`` methods.
    """

    def __init__(self, client: Any):
        self.client = client

    def create(self, config: dict[str, Any]) -> PipelineSubmissionResponse:
        """Run a complete pipeline with the given configuration.

        Parameters
        ----------
        config : dict[str, Any]
            Dictionary containing configuration for the pipeline.

        Returns
        -------
        PipelineSubmissionResponse
            The pipeline submission response.

        Raises
        ------
        ValueError
            If the configuration is invalid, if the server reports the
            project id as malformed or inaccessible, or if the server
            response does not match ``PipelineSubmissionResponse``.
        IonworksError
            Re-raised unchanged for any other API error.
        """
        # Stage 1: local validation. Only failures here are reported as
        # configuration problems.
        try:
            validated_config = PipelineConfig(**config)
            payload = _prepare_payload(validated_config.model_dump())
        except ValidationError as e:
            raise ValueError(f"Invalid pipeline configuration: {e}") from e

        # Stage 2: submission. Known server errors are turned into actionable
        # messages.
        try:
            response_data = self.client.post("/pipelines", payload)
        except IonworksError as e:
            # Fall back to str(e) in case ``message`` is absent; IonworksError's
            # definition is not visible here, and an AttributeError would
            # otherwise mask the real API error.
            error_msg = str(getattr(e, "message", e))
            # Check for invalid UUID format in project_id
            uuid_match = re.search(
                r'invalid input syntax for type uuid: "([^"]*)"', error_msg
            )
            if uuid_match:
                invalid_id = uuid_match.group(1)
                raise ValueError(
                    f"Invalid project_id format: '{invalid_id}' is not a valid UUID. "
                    "Please provide a valid project ID from your project settings page."
                ) from e
            # Check for row-level security policy violation (project not accessible)
            if "violates row-level security policy" in error_msg:
                project_id = validated_config.project_id
                raise ValueError(
                    f"Access denied: The project '{project_id}' is not accessible "
                    "with your API key. Please verify that your API key has access "
                    "to this project."
                ) from e
            # Re-raise original error for other cases
            raise

        # Stage 3: parse the response. A mismatch here is a server/response
        # problem, not a problem with the caller's configuration.
        try:
            return PipelineSubmissionResponse(**response_data)
        except ValidationError as e:
            raise ValueError(
                f"Unexpected response format from /pipelines: {e}"
            ) from e

    def list(self, project_id: str | None = None) -> list[PipelineSubmissionResponse]:
        """List all pipelines.

        Parameters
        ----------
        project_id : str, optional
            The project id to filter pipelines. If not provided (or empty),
            the PROJECT_ID environment variable is used.

        Returns
        -------
        list[PipelineSubmissionResponse]
            List of pipeline submission responses.

        Raises
        ------
        ValueError
            If no non-empty project id is available, the response is not a
            list, or an item does not match ``PipelineSubmissionResponse``.
        """
        from urllib.parse import quote

        if not project_id:
            project_id = os.getenv("PROJECT_ID")
        # Empty strings (explicit or an exported-but-empty env var) would
        # only yield a confusing server-side error.
        if not project_id:
            raise ValueError(
                "project_id is required. Either provide it as an argument "
                "or set the PROJECT_ID environment variable."
            )

        # Percent-encode so characters like '&', '#' or spaces cannot corrupt
        # the query string; a well-formed UUID is left unchanged.
        endpoint = f"/pipelines?project_id={quote(project_id, safe='')}"
        try:
            response_data = self.client.get(endpoint)
            if not isinstance(response_data, list):
                raise ValueError(
                    f"Unexpected response format from {endpoint}: expected a list, "
                    f"got {type(response_data).__name__}"
                )
            return [PipelineSubmissionResponse(**item) for item in response_data]
        except ValidationError as e:
            raise ValueError(f"Invalid item format in list from {endpoint}: {e}") from e

    def get(self, job_id: str) -> PipelineSubmissionResponse:
        """Get the pipeline response for the given job id.

        Parameters
        ----------
        job_id : str
            The job id.

        Returns
        -------
        PipelineSubmissionResponse
            The pipeline submission response.
        """
        response_data = self.client.get(f"/pipelines/{job_id}")
        return PipelineSubmissionResponse(**response_data)

    def result(self, job_id: str) -> PipelineResponse:
        """Get the result for the given job id.

        Parameters
        ----------
        job_id : str
            The job id.

        Returns
        -------
        PipelineResponse
            The pipeline results.
        """
        response_data = self.client.get(f"/pipelines/{job_id}/result")
        return PipelineResponse(**response_data)
ionworks/simulation.py ADDED
@@ -0,0 +1,371 @@
1
+ from __future__ import annotations
2
+
3
+ from datetime import datetime, timedelta
4
+ import time
5
+ from typing import Any, List, cast
6
+
7
+ from pydantic import BaseModel, Field, ValidationError
8
+
9
+
10
class QuickModelConfig(BaseModel):
    """Lightweight model spec (capacity + chemistry) for protocol simulations.

    Both fields have defaults, so an empty dict is a valid quick model.
    """

    capacity: float = Field(description="Cell capacity in Ah", default=1.0)
    chemistry: str = Field(description="Chemistry name", default="NMC/Graphite")
15
+
16
+
17
class ProtocolExperimentConfig(BaseModel):
    """Protocol experiment: a UCP-format YAML protocol plus its name.

    Both fields are required.
    """

    protocol: str = Field(
        description="YAML protocol string (UCP format)",
    )
    name: str = Field(
        description="Protocol name for template naming",
    )
22
+
23
+
24
class ProtocolSimulationRequest(BaseModel):
    """Request body for a single protocol-based simulation.

    ``parameterized_model`` and ``protocol_experiment`` are required; every
    other field is optional and defaults to ``None``.
    """

    # Accepted as-is (Any): quick-model dict, full model dict, or model ID.
    parameterized_model: Any = Field(
        description="Model can be: quick_model dict, full model dict, "
        "or model ID string"
    )
    protocol_experiment: ProtocolExperimentConfig = Field(
        description="Protocol experiment configuration"
    )
    experiment_parameters: dict[str, float] | None = Field(
        description="Experiment parameters for any inputs in the protocol.",
        default=None,
    )
    design_parameters: dict[str, float] | None = Field(
        description="Design parameters for the simulation",
        default=None,
    )
    max_backward_jumps: int | None = Field(
        description="Maximum backward jumps allowed (for goto statements)",
        default=None,
    )
    study_id: str | None = Field(description="Optional study UUID", default=None)
    extra_variables: list[str] | None = Field(
        description=(
            "Optional list of extra variables to include in simulation output "
            "(e.g., ['Negative electrode potential [V]', 'Positive electrode "
            "potential [V]']). If provided, these override any extra variables "
            "defined in the experiment template."
        ),
        default=None,
    )
56
+
57
+
58
class DOERow(BaseModel):
    """One design-of-experiments row: a parameter and how to sample it.

    Only ``type`` and ``name`` are required. Which optional fields matter
    depends on ``type``; that is not enforced here.
    """

    type: str = Field(description="Type: 'range', 'discrete', or 'normal'")
    name: str = Field(description="Parameter name")

    # 'range' rows: bounds plus an optional sample count.
    min: float | None = Field(description="Minimum value", default=None)
    max: float | None = Field(description="Maximum value", default=None)
    count: int | None = Field(
        description="Number of samples (for grid/random)", default=None
    )

    # 'discrete' rows: explicit list of values.
    values: list[float] | None = Field(description="Discrete values", default=None)

    # 'normal' rows: distribution parameters.
    mean: float | None = Field(description="Mean value", default=None)
    std: float | None = Field(description="Standard deviation", default=None)
74
+
75
+
76
class DesignParametersDOE(BaseModel):
    """Design-of-experiments spec: a sampling strategy over a list of rows.

    ``sampling`` and ``rows`` are required; ``count`` is optional.
    """

    sampling: str = Field(
        description="Sampling strategy: 'grid', 'random', or 'latin_hypercube'",
    )
    rows: list[DOERow] = Field(description="DOE row configurations")
    count: int | None = Field(
        description="Total count for non-grid sampling",
        default=None,
    )
86
+
87
+
88
class ProtocolSimulationBatchRequest(BaseModel):
    """Request body for a batch of protocol simulations driven by a DOE.

    ``parameterized_model``, ``protocol_experiment`` and
    ``design_parameters_doe`` are required; the rest default to ``None``.
    """

    # Accepted as-is (Any): quick-model dict, full model dict, or model ID.
    parameterized_model: Any = Field(
        description="Model can be: quick_model dict, full model dict, "
        "or model ID string"
    )
    protocol_experiment: ProtocolExperimentConfig = Field(
        description="Protocol experiment configuration"
    )
    design_parameters_doe: DesignParametersDOE = Field(
        description="Design of experiments configuration"
    )
    experiment_parameters: dict[str, float] | None = Field(
        description="Experiment parameters for any inputs in the protocol.",
        default=None,
    )
    max_backward_jumps: int | None = Field(
        description="Maximum backward jumps allowed (for goto statements)",
        default=None,
    )
    study_id: str | None = Field(description="Optional study UUID", default=None)
    extra_variables: list[str] | None = Field(
        description=(
            "Optional list of extra variables to include in simulation output "
            "(e.g., ['Negative electrode potential [V]', 'Positive electrode "
            "potential [V]']). If provided, these override any extra variables "
            "defined in the experiment template."
        ),
        default=None,
    )
120
+
121
+
122
class SimulationResponse(BaseModel):
    """Identifiers returned when a simulation is created."""

    simulation_id: str = Field(
        description="Simulation UUID",
    )
    job_id: str = Field(
        description="Job UUID",
    )
127
+
128
+
129
class SimulationClient:
    """Client for running simulations.

    All HTTP traffic goes through the wrapped ``client`` object, via its
    ``get(endpoint)`` and ``post(endpoint, json_payload=...)`` methods.
    """

    def __init__(self, client: Any):
        self.client = client

    @staticmethod
    def _parse_simulation_response(endpoint: str, item: Any) -> SimulationResponse:
        """Build a ``SimulationResponse`` from one raw response item.

        Raises
        ------
        ValueError
            If ``item`` lacks the expected fields. This is reported as a
            response problem rather than a configuration problem.
        """
        try:
            return SimulationResponse(**item)
        except ValidationError as e:
            raise ValueError(f"Invalid response from {endpoint}: {e}") from e

    def protocol(self, config: dict[str, Any]) -> SimulationResponse:
        """Create a single protocol-based simulation.

        Parameters
        ----------
        config : dict[str, Any]
            Configuration dictionary containing:
            - parameterized_model: quick_model dict, full model dict, or model ID string
            - protocol_experiment: ProtocolExperimentConfig dict with protocol
              and name fields
            - experiment_parameters: Optional dict[str, float] of values for
              inputs in the protocol
            - design_parameters: Optional dict[str, float]
            - max_backward_jumps: Optional int
            - study_id: Optional str
            - extra_variables: Optional list[str] - Extra variables to include
              in simulation output

        Returns
        -------
        SimulationResponse
            Response containing simulation_id and job_id.

        Raises
        ------
        ValueError
            If the configuration is invalid or the server response does not
            contain simulation_id/job_id.
        """
        endpoint = "/simulations/protocol"
        try:
            validated_config = ProtocolSimulationRequest(**config)
        except ValidationError as e:
            raise ValueError(f"Invalid protocol simulation configuration: {e}") from e
        # None-valued optional fields are omitted from the request body.
        response_data = self.client.post(
            endpoint, json_payload=validated_config.model_dump(exclude_none=True)
        )
        return self._parse_simulation_response(endpoint, response_data)

    def protocol_batch(self, config: dict[str, Any]) -> list[SimulationResponse]:
        """Create multiple protocol-based simulations using DOE.

        Parameters
        ----------
        config : dict[str, Any]
            Configuration dictionary containing:
            - parameterized_model: quick_model dict, full model dict, or model ID string
            - protocol_experiment: ProtocolExperimentConfig dict with protocol
              and name fields
            - design_parameters_doe: DesignParametersDOE dict
            - experiment_parameters: Optional dict[str, float] of values for
              inputs in the protocol
            - max_backward_jumps: Optional int
            - study_id: Optional str
            - extra_variables: Optional list[str] - Extra variables to include
              in simulation output

        Returns
        -------
        list[SimulationResponse]
            List of responses, each containing simulation_id and job_id.

        Raises
        ------
        ValueError
            If the configuration is invalid, the response is not a list, or
            an item lacks simulation_id/job_id.
        """
        endpoint = "/simulations/protocol/batch"
        try:
            validated_config = ProtocolSimulationBatchRequest(**config)
        except ValidationError as e:
            raise ValueError(
                f"Invalid batch protocol simulation configuration: {e}"
            ) from e
        response_data = self.client.post(
            endpoint, json_payload=validated_config.model_dump(exclude_none=True)
        )
        if not isinstance(response_data, list):
            raise ValueError(
                f"Unexpected response format from {endpoint}: expected a "
                f"list, got {type(response_data).__name__}"
            )
        return [
            self._parse_simulation_response(endpoint, item) for item in response_data
        ]

    def list(self) -> list[dict[str, Any]]:
        """List all simulations for the current user.

        Returns
        -------
        list[dict[str, Any]]
            List of simulation objects with joined model and experiment data.

        Raises
        ------
        ValueError
            If the response is not a list.
        """
        endpoint = "/simulations"
        response_data = self.client.get(endpoint)
        if not isinstance(response_data, list):
            raise ValueError(
                f"Unexpected response format from {endpoint}: expected a list, "
                f"got {type(response_data).__name__}"
            )
        return response_data

    def get(self, simulation_id: str) -> dict[str, Any]:
        """Get a specific simulation by ID.

        Parameters
        ----------
        simulation_id : str
            The UUID of the simulation to retrieve.

        Returns
        -------
        dict[str, Any]
            Simulation object with full joined data including model, experiment,
            and simulation_data (null if not completed).
        """
        endpoint = f"/simulations/{simulation_id}"
        response_data = self.client.get(endpoint)
        return cast(dict[str, Any], response_data)

    def get_result(self, simulation_id: str) -> dict[str, Any]:
        """Get simulation data/result for a completed simulation.

        Parameters
        ----------
        simulation_id : str
            The UUID of the simulation.

        Returns
        -------
        dict[str, Any]
            Simulation data object containing time_series, steps, and metrics.

        Raises
        ------
        Exception
            If simulation data is not found (the simulation may not be
            completed yet). The server is expected to answer 404 in that case,
            and the underlying client raises for it; no value is returned.
        """
        endpoint = f"/simulations/{simulation_id}/result"
        response_data = self.client.get(endpoint)
        return cast(dict[str, Any], response_data)

    def wait_for_completion(
        self,
        simulation_id: str | List[str],  # List to avoid conflict with list() method
        timeout: int = 60,
        poll_interval: int = 2,
        verbose: bool = True,
    ) -> dict[str, Any] | List[dict[str, Any]]:
        """Wait for simulation(s) to complete by polling until done or timeout.

        A simulation counts as completed once ``get`` returns a non-null
        ``simulation_data``. Errors from individual polls are tolerated:
        polling continues, and a warning is printed when ``verbose`` is set.

        Parameters
        ----------
        simulation_id : str | list[str]
            Single simulation ID or list of simulation IDs to wait for.
        timeout : int, optional
            Maximum time to wait in seconds (default: 60).
        poll_interval : int, optional
            Time between polls in seconds (default: 2). The final sleep is
            shortened so the wait does not overshoot ``timeout``.
        verbose : bool, optional
            Whether to print status updates (default: True).

        Returns
        -------
        dict[str, Any] | list[dict[str, Any]]
            If a single ID is given, the completed simulation dict. If a
            list is given, the completed simulations in input order. If the
            timeout is reached, this list may be partial and contains only
            the simulations that finished. An empty input list returns ``[]``
            immediately.

        Raises
        ------
        TimeoutError
            If the timeout is reached and no simulation has completed, or a
            single ID was given and it did not complete.
        """
        is_single = isinstance(simulation_id, str)
        if is_single:
            simulation_ids: List[str] = [simulation_id]  # type: ignore[list-item]
        else:
            simulation_ids = simulation_id  # type: ignore[assignment]
        if not simulation_ids:
            # Nothing to wait for.
            return []

        total = len(simulation_ids)
        # Monotonic clock: naive datetime.now() can jump (NTP corrections, DST
        # changes), which would silently shorten or extend the timeout.
        start = time.monotonic()
        completed_simulations: dict[str, dict[str, Any]] = {}

        if verbose:
            print(f"Polling for {total} simulation(s) completion...")

        while time.monotonic() - start < timeout:
            for sim_id in simulation_ids:
                if sim_id in completed_simulations:
                    continue
                try:
                    simulation = self.get(sim_id)
                    if simulation.get("simulation_data") is not None:
                        completed_simulations[sim_id] = simulation
                except Exception as exc:
                    # Best-effort polling: keep going on errors, but make a
                    # persistent failure (e.g. a bad ID) visible when verbose.
                    if verbose:
                        print(f"  Warning: could not poll simulation {sim_id}: {exc}")

            elapsed = time.monotonic() - start
            # Count per input ID rather than using dict size, so duplicate IDs
            # in the input cannot make the "all done" check unreachable.
            done = sum(1 for sim_id in simulation_ids if sim_id in completed_simulations)
            if verbose:
                print(
                    f"  Status: {done}/{total} completed "
                    f"(elapsed: {int(elapsed)}s)"
                )

            if done == total:
                if verbose:
                    print("All simulations completed!")
                break

            # Do not sleep past the deadline.
            time.sleep(min(poll_interval, max(timeout - elapsed, 0)))
        else:
            # Timeout reached
            if verbose:
                done = sum(
                    1 for sim_id in simulation_ids if sim_id in completed_simulations
                )
                print(
                    f"Timeout: Only {done}/{total} "
                    f"simulations completed within {timeout} seconds"
                )
            if not completed_simulations:
                raise TimeoutError(f"No simulations completed within {timeout} seconds")

        # Return results in the same format as input
        if is_single:
            if simulation_ids[0] not in completed_simulations:
                raise TimeoutError(
                    f"Simulation {simulation_ids[0]} did not complete within "
                    f"{timeout} seconds"
                )
            return completed_simulations[simulation_ids[0]]
        # Return list of completed simulations in order
        return [
            completed_simulations[sim_id]
            for sim_id in simulation_ids
            if sim_id in completed_simulations
        ]