prefab 0.5.2__py3-none-any.whl → 1.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
prefab/predict.py ADDED
@@ -0,0 +1,337 @@
1
+ """Prediction functions for ndarrays of device geometries."""
2
+
3
+ import base64
4
+ import io
5
+ import json
6
+ import os
7
+
8
+ import numpy as np
9
+ import requests
10
+ import toml
11
+ from autograd import primitive
12
+ from autograd.extend import defvjp
13
+ from PIL import Image
14
+ from tqdm import tqdm
15
+
16
+ from .geometry import binarize_hard
17
+ from .models import Model
18
+
19
+ BASE_ENDPOINT_URL = "https://prefab-photonics--predict"
20
+ ENDPOINT_VERSION = 2
21
+
22
+
23
def predict_array(
    device_array: np.ndarray,
    model: Model,
    model_type: str,
    binarize: bool,
    gpu: bool = False,
) -> np.ndarray:
    """
    Predict the nanofabrication outcome of a device array using a specified model.

    The device geometry is serialized and posted to a serverless prediction
    service, which runs the requested machine learning model and streams back
    progress events followed by the final result. The prediction can optionally
    run on a GPU-backed endpoint.

    Parameters
    ----------
    device_array : np.ndarray
        A 2D array representing the planar geometry of the device. This array
        undergoes various transformations to predict the nanofabrication
        process.
    model : Model
        The model to use for prediction, representing a specific fabrication
        process and dataset. This model encapsulates details about the
        fabrication foundry, process, material, technology, thickness, and
        sidewall presence, as defined in `models.py`. Each model is associated
        with a version and dataset that detail its creation and the data it was
        trained on.
    model_type : str
        The type of model to use (e.g., 'p' for prediction, 'c' for correction,
        or 's' for SEMulate).
    binarize : bool
        If True, the predicted device geometry will be binarized using a
        threshold method, converting probabilistic predictions into binary
        geometries.
    gpu : bool
        If True, the prediction will be performed on a GPU. Defaults to False.
        Note: the GPU option has more startup overhead and will take longer for
        small devices, but will be faster for larger devices.

    Returns
    -------
    np.ndarray
        The predicted output array.

    Raises
    ------
    RuntimeError
        If the request to the prediction service fails or no result arrives.
    """
    request_headers = _prepare_headers()
    payload = _prepare_predict_data(device_array, model, model_type, binarize)
    gpu_suffix = "-gpu" if gpu else ""
    endpoint_url = f"{BASE_ENDPOINT_URL}{gpu_suffix}-v{ENDPOINT_VERSION}.modal.run"

    try:
        with requests.post(
            endpoint_url,
            data=json.dumps(payload),
            headers=request_headers,
            stream=True,
        ) as response:
            response.raise_for_status()
            prediction = _process_response(response, model_type, binarize)
    except requests.RequestException as e:
        raise RuntimeError(f"Request failed: {e}") from e
    if prediction is None:
        raise RuntimeError("No prediction result received.")
    return prediction
93
+
94
+
95
def _predict_array_with_grad(
    device_array: np.ndarray, model: Model
) -> tuple[np.ndarray, np.ndarray]:
    """
    Predict the nanofabrication outcome of a device array and compute its gradient.

    This function predicts the outcome of the nanofabrication process for a
    given device array using a specified model. It also computes the gradient
    of the prediction with respect to the input device array.

    Notes
    -----
    This function is currently not used by the main `predict_array` function,
    which provides features (e.g., GPU support and a progress bar) that this
    endpoint does not, for now.

    Parameters
    ----------
    device_array : np.ndarray
        A 2D array representing the planar geometry of the device.
    model : Model
        The model to use for prediction, representing a specific fabrication process.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        The predicted output array and gradient array.
    """
    headers = _prepare_headers()
    predict_data = _prepare_predict_data(device_array, model, "p", False)
    endpoint_url = f"{BASE_ENDPOINT_URL}-with-grad-v{ENDPOINT_VERSION}.modal.run"

    response = requests.post(
        endpoint_url, data=json.dumps(predict_data), headers=headers
    )
    # Fail fast on an HTTP error instead of raising an opaque KeyError below.
    response.raise_for_status()
    # Parse the JSON body once instead of re-parsing it for every field.
    body = response.json()
    prediction_array = _decode_array(body["prediction_array"])
    gradient_array = _decode_array(body["gradient_array"])
    # The gradient is transmitted as a normalized image; rescale it back to
    # its original [gradient_min, gradient_max] range.
    gradient_min = body["gradient_min"]
    gradient_max = body["gradient_max"]
    gradient_range = gradient_max - gradient_min
    gradient_array = gradient_array * gradient_range + gradient_min
    return (prediction_array, gradient_array)
136
+
137
+
138
@primitive
def predict_array_with_grad(device_array: np.ndarray, model: Model) -> np.ndarray:
    """
    Predict the nanofabrication outcome of a device array and compute its gradient.

    Predicts the fabrication outcome for `device_array` using `model` and also
    retrieves the gradient of the prediction with respect to the input, which
    makes the function suitable for automatic differentiation applications
    (e.g., autograd).

    Parameters
    ----------
    device_array : np.ndarray
        A 2D array representing the planar geometry of the device.
    model : Model
        The model to use for prediction, representing a specific fabrication process.

    Returns
    -------
    np.ndarray
        The predicted output array.
    """
    prediction, gradient = _predict_array_with_grad(
        device_array=device_array, model=model
    )
    # Stash the gradient on the function object so the registered VJP can
    # retrieve it without a second service call.
    predict_array_with_grad.gradient_array = gradient  # type: ignore
    return prediction
165
+
166
+
167
def predict_array_with_grad_vjp(ans: np.ndarray, device_array: np.ndarray, *args):
    """
    Define the vector-Jacobian product (VJP) for the prediction function.

    Parameters
    ----------
    ans : np.ndarray
        The output of the `predict_array_with_grad` function.
    device_array : np.ndarray
        The input device array for which the gradient is computed.
    *args :
        Additional arguments that aren't used in the VJP computation.

    Returns
    -------
    function
        A function that computes the VJP given an upstream gradient `g`.
    """
    # Gradient captured during the forward pass of predict_array_with_grad.
    stored_grad = predict_array_with_grad.gradient_array  # type: ignore
    return lambda g: g * stored_grad
191
+
192
+
193
+ defvjp(predict_array_with_grad, predict_array_with_grad_vjp)
194
+
195
+
196
def _encode_array(array):
    """Encode an ndarray as a base64 PNG string for transmission."""
    # Scale [0, 1] values to 8-bit grayscale before PNG encoding.
    png_buffer = io.BytesIO()
    Image.fromarray(np.uint8(array * 255)).save(png_buffer, format="PNG")
    return base64.b64encode(png_buffer.getvalue()).decode("utf-8")
203
+
204
+
205
def _decode_array(encoded_png):
    """Decode a base64-encoded PNG image into an ndarray scaled to [0, 1]."""
    png_bytes = base64.b64decode(encoded_png)
    decoded_image = Image.open(io.BytesIO(png_bytes))
    return np.array(decoded_image) / 255
210
+
211
+
212
def _read_tokens():
    """Read access and refresh tokens from the `~/.prefab.toml` configuration file."""
    token_file_path = os.path.expanduser("~/.prefab.toml")
    try:
        with open(token_file_path) as file:
            tokens = toml.load(file)
    except FileNotFoundError:
        raise FileNotFoundError(
            "Could not validate user.\n"
            "Please update prefab using: pip install --upgrade prefab.\n"
            "Signup/login and generate a new token.\n"
            "See https://docs.prefabphotonics.com/."
        ) from None
    access_token = tokens.get("access_token")
    refresh_token = tokens.get("refresh_token")
    if not access_token or not refresh_token:
        raise ValueError("Tokens not found in the configuration file.")
    return access_token, refresh_token
230
+
231
+
232
def _prepare_headers():
    """Build the authorization headers for a prediction request."""
    access_token, refresh_token = _read_tokens()
    headers = {"Authorization": f"Bearer {access_token}"}
    headers["X-Refresh-Token"] = refresh_token
    return headers
239
+
240
+
241
def _prepare_predict_data(device_array, model, model_type, binarize):
    """Assemble the JSON-serializable payload for the prediction request."""
    # Drop singleton dimensions so the geometry encodes as a 2D image.
    encoded_device = _encode_array(np.squeeze(device_array))
    return {
        "device_array": encoded_device,
        "model": model.to_json(),
        "model_type": model_type,
        "binary": binarize,
    }
249
+
250
+
251
def _process_response(response, model_type, binarize):
    """
    Process the streaming (server-sent events) response of a prediction request.

    Parameters
    ----------
    response : requests.Response
        A streaming response whose lines alternate between `event:` and
        `data:` records.
    model_type : str
        The model type ('p', 'c', or 's'); used to label the progress bar.
    binarize : bool
        Forwarded to the result handler to optionally binarize the prediction.

    Returns
    -------
    np.ndarray or None
        The prediction array from the first `result` event, or None if the
        stream ends without one.
    """
    event_type = None
    model_descriptions = {
        "p": "Prediction",
        "c": "Correction",
        "s": "SEMulate",
    }
    progress_bar = tqdm(
        total=100,
        desc=model_descriptions.get(model_type, "Processing"),
        unit="%",
        colour="green",
        bar_format="{l_bar}{bar:30}{r_bar}{bar:-10b}",
    )
    # try/finally guarantees the progress bar is closed even when an event
    # handler raises (e.g., a server `error` event), leaving the terminal in
    # a clean state; the original code leaked the bar on that path.
    try:
        for line in response.iter_lines():
            if not line:
                continue
            decoded_line = line.decode("utf-8").strip()
            if decoded_line.startswith("event:"):
                event_type = decoded_line.split(":", 1)[1].strip()
            elif decoded_line.startswith("data:"):
                data_content = _parse_data_line(decoded_line)
                result = _handle_event(event_type, data_content, progress_bar, binarize)
                if result is not None:
                    return result
    finally:
        progress_bar.close()
279
+
280
+
281
+ def _parse_data_line(decoded_line):
282
+ """Parse a data line from the response stream."""
283
+ data_line = decoded_line.split(":", 1)[1].strip()
284
+ try:
285
+ return json.loads(data_line)
286
+ except json.JSONDecodeError:
287
+ raise ValueError(f"Failed to decode JSON: {data_line}") from None
288
+
289
+
290
+ def _handle_event(event_type, data_content, progress_bar, binarize):
291
+ """Handle different types of events received from the server."""
292
+ if event_type == "progress":
293
+ _update_progress(progress_bar, data_content)
294
+ elif event_type == "result":
295
+ return _process_result(data_content, binarize)
296
+ elif event_type == "end":
297
+ print("Stream ended.")
298
+ elif event_type == "auth":
299
+ _update_tokens(data_content.get("auth", {}))
300
+ elif event_type == "error":
301
+ raise ValueError(f"{data_content['error']}")
302
+
303
+
304
+ def _update_progress(progress_bar, data_content):
305
+ """Update the progress bar based on the progress event."""
306
+ progress = round(100 * data_content.get("progress", 0))
307
+ progress_bar.update(progress - progress_bar.n)
308
+
309
+
310
def _process_result(data_content, binarize):
    """Decode all `result*` fields of a result event into a stacked prediction."""
    result_keys = [key for key in sorted(data_content) if key.startswith("result")]
    if not result_keys:
        return None
    decoded = [_decode_array(data_content[key]) for key in result_keys]
    # Stack multiple result layers along a trailing axis.
    prediction = np.stack(decoded, axis=-1)
    return binarize_hard(prediction) if binarize else prediction
322
+
323
+
324
+ def _update_tokens(auth_data):
325
+ """Update tokens if new tokens are provided in the auth event."""
326
+ new_access_token = auth_data.get("new_access_token")
327
+ new_refresh_token = auth_data.get("new_refresh_token")
328
+ if new_access_token and new_refresh_token:
329
+ prefab_file_path = os.path.expanduser("~/.prefab.toml")
330
+ with open(prefab_file_path, "w", encoding="utf-8") as toml_file:
331
+ toml.dump(
332
+ {
333
+ "access_token": new_access_token,
334
+ "refresh_token": new_refresh_token,
335
+ },
336
+ toml_file,
337
+ )