prefab 1.1.1__py3-none-any.whl → 1.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
prefab/predict.py ADDED
@@ -0,0 +1,260 @@
1
+ import base64
2
+ import io
3
+ import json
4
+ import os
5
+
6
+ import numpy as np
7
+ import requests
8
+ import toml
9
+ from PIL import Image
10
+ from tqdm import tqdm
11
+
12
+ from .geometry import binarize_hard
13
+ from .models import Model
14
+
15
+ BASE_URL = "https://prefab-photonics--predict"
16
+
17
+
18
def predict_array(
    device_array: np.ndarray,
    model: Model,
    model_type: str,
    binarize: bool,
    gpu: bool = False,
) -> np.ndarray:
    """
    Request a prediction for a device array from the remote prediction service.

    The device array and model description are posted to the service, and the
    streamed response is handed to ``_process_response``, which yields the
    predicted array. The GPU endpoint is used when ``gpu`` is true.

    Parameters
    ----------
    device_array : np.ndarray
        The input device array to be predicted.
    model : Model
        The model to use for prediction.
    model_type : str
        The type of model to use (e.g., 'p', 'c', 's').
    binarize : bool
        Whether to binarize the output.
    gpu : bool, optional
        Whether to use the GPU endpoint for prediction. Defaults to False.

    Returns
    -------
    np.ndarray
        The predicted output array, as produced by ``_process_response``.

    Raises
    ------
    RuntimeError
        If the request to the prediction service fails.
    """
    headers = _prepare_headers()
    body = json.dumps(
        _prepare_predict_data(device_array, model, model_type, binarize)
    )

    # Both endpoints share BASE_URL and differ only in their suffix.
    if gpu:
        endpoint_url = f"{BASE_URL}-gpu-v1.modal.run"
    else:
        endpoint_url = f"{BASE_URL}-v1.modal.run"

    try:
        # stream=True lets the response be consumed line by line as the
        # service emits its events.
        response = requests.post(
            endpoint_url, data=body, headers=headers, stream=True
        )
        with response:
            response.raise_for_status()
            return _process_response(response, model_type, binarize)
    except requests.RequestException as e:
        raise RuntimeError(f"Request failed: {e}") from e
70
+
71
+
72
def predict_array_with_grad(
    device_array: np.ndarray, model: Model, model_type: str
) -> tuple[np.ndarray, np.ndarray]:
    """
    Predicts the output array and its gradient for a given device array using a
    specified model.

    This function sends the device array to a prediction service, which returns
    both the predicted outcome and a gradient array, each encoded as a base64
    PNG. The decoded gradient is rescaled with the ``gradient_min`` and
    ``gradient_max`` values reported in the same response.

    Parameters
    ----------
    device_array : np.ndarray
        The input device array to be predicted.
    model : Model
        The model to use for prediction.
    model_type : str
        The type of model to use (e.g., 'p', 'c', 's').

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        A tuple containing the predicted output array and its gradient.

    Raises
    ------
    RuntimeError
        If the request to the prediction service fails or the service answers
        with an HTTP error status.
    """
    headers = _prepare_headers()
    predict_data = _prepare_predict_data(device_array, model, model_type, False)
    endpoint_url = f"{BASE_URL}-with-grad-v1.modal.run"

    try:
        response = requests.post(
            endpoint_url, data=json.dumps(predict_data), headers=headers
        )
        # Surface HTTP errors here instead of failing later on a missing key.
        response.raise_for_status()
        # Parse the body once; every field below is read from this dict.
        payload = response.json()
    except requests.RequestException as e:
        # Same error contract as predict_array.
        raise RuntimeError(f"Request failed: {e}") from e

    prediction_array = _decode_array(payload["prediction_array"])
    gradient_array = _decode_array(payload["gradient_array"])

    # Map the decoded gradient back onto the [gradient_min, gradient_max] span
    # reported by the service.
    gradient_min = payload["gradient_min"]
    gradient_range = payload["gradient_max"] - gradient_min
    gradient_array = gradient_array * gradient_range + gradient_min

    return (prediction_array, gradient_array)
117
+
118
+
119
def _encode_array(array):
    """
    Turn a numpy array into a base64 string wrapping a PNG image.

    The array is multiplied by 255 and cast to ``uint8`` before the image is
    built, so callers presumably pass values in the 0-1 range (not enforced
    here).
    """
    png_buffer = io.BytesIO()
    Image.fromarray(np.uint8(array * 255)).save(png_buffer, format="PNG")
    return base64.b64encode(png_buffer.getvalue()).decode("utf-8")
126
+
127
+
128
def _decode_array(encoded_png):
    """
    Turn a base64 encoded PNG image back into a numpy array.

    Pixel values are divided by 255, mirroring the scaling applied by
    ``_encode_array``.
    """
    png_stream = io.BytesIO(base64.b64decode(encoded_png))
    pixels = np.array(Image.open(png_stream))
    return pixels / 255
133
+
134
+
135
def _read_tokens():
    """
    Read access and refresh tokens from the configuration file.

    The tokens are loaded from ``~/.prefab.toml``. The file is read as UTF-8,
    the same encoding this module uses when it rewrites the file after a token
    refresh.

    Returns
    -------
    tuple
        ``(access_token, refresh_token)`` as stored in the file.

    Raises
    ------
    FileNotFoundError
        If the configuration file does not exist.
    ValueError
        If either token is missing or empty in the configuration file.
    """
    token_file_path = os.path.expanduser("~/.prefab.toml")
    try:
        # Only the open() call is guarded: it is the step that signals a
        # missing configuration file.
        file = open(token_file_path, encoding="utf-8")
    except FileNotFoundError:
        raise FileNotFoundError(
            "Could not validate user.\n"
            "Please update prefab using: pip install --upgrade prefab.\n"
            "Signup/login and generate a new token.\n"
            "See https://www.prefabphotonics.com/docs/guides/quickstart."
        ) from None

    with file:
        tokens = toml.load(file)

    access_token = tokens.get("access_token")
    refresh_token = tokens.get("refresh_token")
    if not access_token or not refresh_token:
        raise ValueError("Tokens not found in the configuration file.")
    return access_token, refresh_token
153
+
154
+
155
def _prepare_headers():
    """
    Build the HTTP headers sent with a prediction request.

    The access token goes into a bearer ``Authorization`` header and the
    refresh token into ``X-Refresh-Token``; both come from ``_read_tokens``.
    """
    access_token, refresh_token = _read_tokens()
    headers = {}
    headers["Authorization"] = f"Bearer {access_token}"
    headers["X-Refresh-Token"] = refresh_token
    return headers
162
+
163
+
164
def _prepare_predict_data(device_array, model, model_type, binarize):
    """
    Assemble the payload dictionary for a prediction request.

    The device array is squeezed and PNG/base64 encoded, and the model is
    serialized through its ``to_json`` method.
    """
    encoded_device = _encode_array(np.squeeze(device_array))
    model_json = model.to_json()
    return {
        "device_array": encoded_device,
        "model": model_json,
        "model_type": model_type,
        "binary": binarize,
    }
172
+
173
+
174
def _process_response(response, model_type, binarize):
    """
    Process the streaming response from the prediction request.

    The stream is read line by line. An ``event:`` line sets the current event
    type, and each following ``data:`` line is parsed and dispatched to
    ``_handle_event`` under that type. The first non-None value returned by the
    handler is returned to the caller.

    Parameters
    ----------
    response
        Streaming response object exposing ``iter_lines()``.
    model_type : str
        Model type code; selects the progress bar description.
    binarize : bool
        Forwarded to the event handler for result events.

    Returns
    -------
    The value produced by the first event that yields a result, or None if the
    stream ends without one.
    """
    event_type = None
    model_descriptions = {
        "p": "Prediction",
        "c": "Correction",
        "s": "SEMulate",
    }
    progress_bar = tqdm(
        total=100,
        desc=model_descriptions.get(model_type, "Processing"),
        unit="%",
        colour="green",
        bar_format="{l_bar}{bar:30}{r_bar}{bar:-10b}",
    )

    # The finally clause closes the bar on every exit path, including when
    # parsing a data line or handling an event raises.
    try:
        for line in response.iter_lines():
            if not line:
                continue
            decoded_line = line.decode("utf-8").strip()
            if decoded_line.startswith("event:"):
                event_type = decoded_line.split(":", 1)[1].strip()
            elif decoded_line.startswith("data:"):
                data_content = _parse_data_line(decoded_line)
                result = _handle_event(
                    event_type, data_content, progress_bar, binarize
                )
                if result is not None:
                    return result
    finally:
        progress_bar.close()
    return None
202
+
203
+
204
def _parse_data_line(decoded_line):
    """
    Extract and decode the JSON payload of a ``data:`` line from the stream.

    Everything after the first colon is stripped and parsed as JSON. A payload
    that is not valid JSON raises ``ValueError``.
    """
    payload_text = decoded_line.split(":", 1)[1]
    payload_text = payload_text.strip()
    try:
        parsed = json.loads(payload_text)
    except json.JSONDecodeError:
        raise ValueError(f"Failed to decode JSON: {payload_text}") from None
    return parsed
211
+
212
+
213
def _handle_event(event_type, data_content, progress_bar, binarize):
    """
    Dispatch one server event according to its type.

    ``result`` events return the processed prediction and ``error`` events
    raise ``ValueError`` with the server-provided message. ``progress``,
    ``end`` and ``auth`` events are handled for their side effects and, like
    unrecognized event types, return None.
    """
    # The two branches that leave the function with a value or an exception.
    if event_type == "result":
        return _process_result(data_content, binarize)
    if event_type == "error":
        raise ValueError(f"{data_content['error']}")

    # Side-effect-only events.
    if event_type == "progress":
        _update_progress(progress_bar, data_content)
    elif event_type == "end":
        print("Stream ended.")
    elif event_type == "auth":
        _update_tokens(data_content.get("auth", {}))
    return None
225
+
226
+
227
def _update_progress(progress_bar, data_content):
    """
    Advance the progress bar to the percentage reported by a progress event.

    The event's ``progress`` value (0 when absent) is scaled by 100 and
    rounded. The bar is then moved by the difference between that target and
    its current position ``n``.
    """
    fraction = data_content.get("progress", 0)
    target_percent = round(100 * fraction)
    delta = target_percent - progress_bar.n
    progress_bar.update(delta)
231
+
232
+
233
def _process_result(data_content, binarize):
    """
    Build the prediction array from a result event.

    Every entry whose key starts with ``result`` is decoded, and the decoded
    arrays are stacked along a new last axis in sorted key order. The stack is
    passed through ``binarize_hard`` when ``binarize`` is true. Returns None
    when the event carries no such entries.
    """
    # NOTE(review): keys are sorted as plain strings, so a suffix such as "10"
    # would sort before "2" — confirm the key naming used by the service.
    results = []
    for key in sorted(data_content.keys()):
        if key.startswith("result"):
            results.append(_decode_array(data_content[key]))

    if not results:
        return None

    prediction = np.stack(results, axis=-1)
    return binarize_hard(prediction) if binarize else prediction
245
+
246
+
247
def _update_tokens(auth_data):
    """
    Persist refreshed tokens delivered by an auth event.

    Nothing is written unless both ``new_access_token`` and
    ``new_refresh_token`` are present and non-empty. When they are,
    ``~/.prefab.toml`` is overwritten with the new pair.
    """
    access = auth_data.get("new_access_token")
    refresh = auth_data.get("new_refresh_token")
    if not (access and refresh):
        return

    refreshed = {
        "access_token": access,
        "refresh_token": refresh,
    }
    config_path = os.path.expanduser("~/.prefab.toml")
    with open(config_path, "w", encoding="utf-8") as toml_file:
        toml.dump(refreshed, toml_file)
prefab/read.py CHANGED
@@ -244,7 +244,7 @@ def from_gdsfactory(
244
244
  If the gdsfactory package is not installed.
245
245
  """
246
246
  try:
247
- import gdsfactory as gf
247
+ import gdsfactory as gf # noqa: F401
248
248
  except ImportError:
249
249
  raise ImportError(
250
250
  "The gdsfactory package is required to use this function; "
@@ -352,6 +352,10 @@ def get_sem_resolution(sem_path: str, sem_resolution_key: str) -> float:
352
352
  Extracts the resolution of a scanning electron microscope (SEM) image from its
353
353
  metadata.
354
354
 
355
+ Note:
356
+ -----
357
+ This function is used internally and may not be useful for most users.
358
+
355
359
  Parameters
356
360
  ----------
357
361
  sem_path : str