prefab-1.1.1-py3-none-any.whl → prefab-1.1.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prefab/__init__.py +3 -2
- prefab/compare.py +17 -0
- prefab/device.py +40 -177
- prefab/geometry.py +43 -1
- prefab/models.py +1 -1
- prefab/predict.py +260 -0
- prefab/read.py +5 -1
- prefab/shapes.py +208 -201
- {prefab-1.1.1.dist-info → prefab-1.1.3.dist-info}/METADATA +12 -12
- prefab-1.1.3.dist-info/RECORD +13 -0
- prefab-1.1.1.dist-info/RECORD +0 -12
- {prefab-1.1.1.dist-info → prefab-1.1.3.dist-info}/WHEEL +0 -0
- {prefab-1.1.1.dist-info → prefab-1.1.3.dist-info}/licenses/LICENSE +0 -0
prefab/predict.py
ADDED

@@ -0,0 +1,260 @@
+import base64
+import io
+import json
+import os
+
+import numpy as np
+import requests
+import toml
+from PIL import Image
+from tqdm import tqdm
+
+from .geometry import binarize_hard
+from .models import Model
+
+BASE_URL = "https://prefab-photonics--predict"
+
+
+def predict_array(
+    device_array: np.ndarray,
+    model: Model,
+    model_type: str,
+    binarize: bool,
+    gpu: bool = False,
+) -> np.ndarray:
+    """
+    Predicts the output array for a given device array using a specified model.
+
+    This function sends the device array to a prediction service, which uses a machine
+    learning model to predict the outcome of the nanofabrication process. The prediction
+    can be performed on a GPU if specified.
+
+    Parameters
+    ----------
+    device_array : np.ndarray
+        The input device array to be predicted.
+    model : Model
+        The model to use for prediction.
+    model_type : str
+        The type of model to use (e.g., 'p', 'c', 's').
+    binarize : bool
+        Whether to binarize the output.
+    gpu : bool, optional
+        Whether to use GPU for prediction. Defaults to False.
+
+    Returns
+    -------
+    np.ndarray
+        The predicted output array.
+
+    Raises
+    ------
+    RuntimeError
+        If the request to the prediction service fails.
+    """
+    headers = _prepare_headers()
+    predict_data = _prepare_predict_data(device_array, model, model_type, binarize)
+    endpoint_url = f"{BASE_URL}-gpu-v1.modal.run" if gpu else f"{BASE_URL}-v1.modal.run"
+
+    try:
+        with requests.post(
+            endpoint_url,
+            data=json.dumps(predict_data),
+            headers=headers,
+            stream=True,
+        ) as response:
+            response.raise_for_status()
+            return _process_response(response, model_type, binarize)
+    except requests.RequestException as e:
+        raise RuntimeError(f"Request failed: {e}") from e
+
+
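For orientation, a minimal sketch of how this entry point might be called. The helper and layout below are hypothetical: `model` must be a prefab.models.Model instance (which instances exist is not shown in this diff), and a valid ~/.prefab.toml token file is assumed to be present.

import numpy as np

from prefab.predict import predict_array


def predict_square(model):
    """Run a prediction on a toy layout; `model` is a prefab.models.Model."""
    # Hypothetical 200 x 200 binary layout: a square of core in cladding.
    device_array = np.zeros((200, 200))
    device_array[50:150, 50:150] = 1.0
    return predict_array(
        device_array=device_array,
        model=model,        # placeholder: not defined in this diff
        model_type="p",     # 'p' = prediction, per the docstring above
        binarize=True,      # snap the returned array to {0, 1}
    )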
+def predict_array_with_grad(
+    device_array: np.ndarray, model: Model, model_type: str
+) -> tuple[np.ndarray, np.ndarray]:
+    """
+    Predicts the output array and its gradient for a given device array using a
+    specified model.
+
+    This function sends the device array to a prediction service, which uses a machine
+    learning model to predict both the outcome and the gradient of the nanofabrication
+    process.
+
+    Parameters
+    ----------
+    device_array : np.ndarray
+        The input device array to be predicted.
+    model : Model
+        The model to use for prediction.
+    model_type : str
+        The type of model to use (e.g., 'p', 'c', 's').
+
+    Returns
+    -------
+    tuple[np.ndarray, np.ndarray]
+        A tuple containing the predicted output array and its gradient.
+
+    Raises
+    ------
+    RuntimeError
+        If the request to the prediction service fails.
+    """
+    headers = _prepare_headers()
+    predict_data = _prepare_predict_data(device_array, model, model_type, False)
+    endpoint_url = f"{BASE_URL}-with-grad-v1.modal.run"
+
+    response = requests.post(
+        endpoint_url, data=json.dumps(predict_data), headers=headers
+    )
+    prediction_array = _decode_array(response.json()["prediction_array"])
+    gradient_array = _decode_array(response.json()["gradient_array"])
+    gradient_min = response.json()["gradient_min"]
+    gradient_max = response.json()["gradient_max"]
+    gradient_range = gradient_max - gradient_min
+    gradient_array = gradient_array * gradient_range + gradient_min
+
+    return (prediction_array, gradient_array)
+
+
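Note the gradient handling above: `_decode_array` returns values normalized to [0, 1] (the gradient travels as an 8-bit PNG), so the client rescales with the `gradient_min`/`gradient_max` bounds from the response. A small sketch of that inverse mapping, with made-up bounds:

import numpy as np

g_norm = np.array([0.0, 0.5, 1.0])    # decoded PNG values in [0, 1]
g_min, g_max = -2.0, 6.0              # hypothetical bounds from the response
g = g_norm * (g_max - g_min) + g_min  # recovers [-2.0, 2.0, 6.0]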
+def _encode_array(array):
+    """Encode a numpy array as a PNG image and return the base64 encoded string."""
+    image = Image.fromarray(np.uint8(array * 255))
+    buffered = io.BytesIO()
+    image.save(buffered, format="PNG")
+    encoded_png = base64.b64encode(buffered.getvalue()).decode("utf-8")
+    return encoded_png
+
+
+def _decode_array(encoded_png):
+    """Decode a base64 encoded PNG image and return a numpy array."""
+    binary_data = base64.b64decode(encoded_png)
+    image = Image.open(io.BytesIO(binary_data))
+    return np.array(image) / 255
+
+
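This encode/decode pair round-trips arrays through 8-bit PNGs, so values are quantized to steps of 1/255 (and the `np.uint8(array * 255)` cast truncates rather than rounds). A self-contained sketch of that behaviour, using standalone copies of the two helpers:

import base64
import io

import numpy as np
from PIL import Image


def encode(array):
    image = Image.fromarray(np.uint8(array * 255))  # truncating 8-bit cast
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def decode(encoded_png):
    return np.array(Image.open(io.BytesIO(base64.b64decode(encoded_png)))) / 255


array = np.random.rand(64, 64)
assert np.allclose(array, decode(encode(array)), atol=1 / 255)  # lossy, 8-bit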
+def _read_tokens():
+    """Read access and refresh tokens from the configuration file."""
+    token_file_path = os.path.expanduser("~/.prefab.toml")
+    try:
+        with open(token_file_path) as file:
+            tokens = toml.load(file)
+            access_token = tokens.get("access_token")
+            refresh_token = tokens.get("refresh_token")
+            if not access_token or not refresh_token:
+                raise ValueError("Tokens not found in the configuration file.")
+            return access_token, refresh_token
+    except FileNotFoundError:
+        raise FileNotFoundError(
+            "Could not validate user.\n"
+            "Please update prefab using: pip install --upgrade prefab.\n"
+            "Signup/login and generate a new token.\n"
+            "See https://www.prefabphotonics.com/docs/guides/quickstart."
+        ) from None
+
+
+def _prepare_headers():
+    """Prepare HTTP headers for the request."""
+    access_token, refresh_token = _read_tokens()
+    return {
+        "Authorization": f"Bearer {access_token}",
+        "X-Refresh-Token": refresh_token,
+    }
+
+
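Taken together, `_read_tokens` and `_update_tokens` (further down) imply the on-disk shape of ~/.prefab.toml: a flat table with two string keys. A quick way to see it, with placeholder values:

import toml

print(toml.dumps({"access_token": "<ACCESS>", "refresh_token": "<REFRESH>"}))
# access_token = "<ACCESS>"
# refresh_token = "<REFRESH>"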
+def _prepare_predict_data(device_array, model, model_type, binarize):
+    """Prepare the data payload for the prediction request."""
+    return {
+        "device_array": _encode_array(np.squeeze(device_array)),
+        "model": model.to_json(),
+        "model_type": model_type,
+        "binary": binarize,
+    }
+
+
+def _process_response(response, model_type, binarize):
+    """Process the streaming response from the prediction request."""
+    event_type = None
+    model_descriptions = {
+        "p": "Prediction",
+        "c": "Correction",
+        "s": "SEMulate",
+    }
+    progress_bar = tqdm(
+        total=100,
+        desc=model_descriptions.get(model_type, "Processing"),
+        unit="%",
+        colour="green",
+        bar_format="{l_bar}{bar:30}{r_bar}{bar:-10b}",
+    )
+
+    for line in response.iter_lines():
+        if line:
+            decoded_line = line.decode("utf-8").strip()
+            if decoded_line.startswith("event:"):
+                event_type = decoded_line.split(":", 1)[1].strip()
+            elif decoded_line.startswith("data:"):
+                data_content = _parse_data_line(decoded_line)
+                result = _handle_event(event_type, data_content, progress_bar, binarize)
+                if result is not None:
+                    progress_bar.close()
+                    return result
+    progress_bar.close()
+
+
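`_process_response` consumes a server-sent-events style stream of alternating `event:`/`data:` lines. A hypothetical stream run through the same parsing logic, stripped of the progress bar and PNG decoding (the "..." placeholder stands in for a base64-encoded PNG):

import json

stream = [
    b"event: progress",
    b'data: {"progress": 0.25}',
    b"event: progress",
    b'data: {"progress": 1.0}',
    b"event: result",
    b'data: {"result0": "..."}',
]

event_type = None
for line in stream:
    decoded_line = line.decode("utf-8").strip()
    if decoded_line.startswith("event:"):
        event_type = decoded_line.split(":", 1)[1].strip()
    elif decoded_line.startswith("data:"):
        print(event_type, json.loads(decoded_line.split(":", 1)[1].strip()))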
+def _parse_data_line(decoded_line):
+    """Parse a data line from the response stream."""
+    data_line = decoded_line.split(":", 1)[1].strip()
+    try:
+        return json.loads(data_line)
+    except json.JSONDecodeError:
+        raise ValueError(f"Failed to decode JSON: {data_line}") from None
+
+
+def _handle_event(event_type, data_content, progress_bar, binarize):
+    """Handle different types of events received from the server."""
+    if event_type == "progress":
+        _update_progress(progress_bar, data_content)
+    elif event_type == "result":
+        return _process_result(data_content, binarize)
+    elif event_type == "end":
+        print("Stream ended.")
+    elif event_type == "auth":
+        _update_tokens(data_content.get("auth", {}))
+    elif event_type == "error":
+        raise ValueError(f"{data_content['error']}")
+
+
+def _update_progress(progress_bar, data_content):
+    """Update the progress bar based on the progress event."""
+    progress = round(100 * data_content.get("progress", 0))
+    progress_bar.update(progress - progress_bar.n)
+
+
+def _process_result(data_content, binarize):
+    """Process the result event and return the prediction."""
+    results = [
+        _decode_array(data_content[key])
+        for key in sorted(data_content.keys())
+        if key.startswith("result")
+    ]
+    if results:
+        prediction = np.stack(results, axis=-1)
+        if binarize:
+            prediction = binarize_hard(prediction)
+        return prediction
+
+
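`_process_result` collects every `result*` key and stacks the decoded channels along a trailing axis, so multi-image responses (e.g., several SEMulate output layers) come back as one array, with sorted keys fixing the channel order. In miniature:

import numpy as np

channels = {"result1": np.ones((4, 4)), "result0": np.zeros((4, 4))}
stacked = np.stack([channels[k] for k in sorted(channels)], axis=-1)  # result0 first
assert stacked.shape == (4, 4, 2)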
+def _update_tokens(auth_data):
+    """Update tokens if new tokens are provided in the auth event."""
+    new_access_token = auth_data.get("new_access_token")
+    new_refresh_token = auth_data.get("new_refresh_token")
+    if new_access_token and new_refresh_token:
+        prefab_file_path = os.path.expanduser("~/.prefab.toml")
+        with open(prefab_file_path, "w", encoding="utf-8") as toml_file:
+            toml.dump(
+                {
+                    "access_token": new_access_token,
+                    "refresh_token": new_refresh_token,
+                },
+                toml_file,
+            )
prefab/read.py
CHANGED

@@ -244,7 +244,7 @@ def from_gdsfactory(
         If the gdsfactory package is not installed.
     """
     try:
-        import gdsfactory as gf
+        import gdsfactory as gf  # noqa: F401
     except ImportError:
         raise ImportError(
             "The gdsfactory package is required to use this function; "

@@ -352,6 +352,10 @@ def get_sem_resolution(sem_path: str, sem_resolution_key: str) -> float:
     Extracts the resolution of a scanning electron microscope (SEM) image from its
     metadata.
 
+    Note:
+    -----
+    This function is used internally and may not be useful for most users.
+
     Parameters
     ----------
     sem_path : str