prefab 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
prefab/predict.py CHANGED
@@ -1,37 +1,134 @@
1
- """Prediction functions for ndarrays of device geometries."""
1
+ """
2
+ Serverless prediction interface for nanofabrication modeling.
3
+
4
+ This module provides functions for predicting nanofabrication outcomes using machine
5
+ learning models hosted on a serverless platform. It supports multiple input formats
6
+ (ndarrays, polygons, GDSII files) and model types (prediction, correction,
7
+ segmentation). Gradient computation is available for inverse design applications
8
+ using automatic differentiation.
9
+ """
2
10
 
3
11
  import base64
4
12
  import io
5
13
  import json
6
14
  import os
15
+ from typing import Any
7
16
 
8
17
  import gdstk
9
18
  import numpy as np
19
+ import numpy.typing as npt
10
20
  import requests
11
21
  import toml
12
22
  from autograd import primitive
13
23
  from autograd.extend import defvjp
14
24
  from PIL import Image
15
- from tqdm import tqdm
16
25
 
17
26
  from .geometry import binarize_hard
18
27
  from .models import Model
19
28
 
20
29
  BASE_ENDPOINT_URL = "https://prefab-photonics--predict"
21
- ENDPOINT_VERSION = "2"
30
+ ENDPOINT_VERSION = "3"
31
+
32
+
33
def predict_array(
    device_array: npt.NDArray[Any],
    model: Model,
    model_type: str,
    binarize: bool,
) -> npt.NDArray[Any]:
    """
    Predict the nanofabrication outcome of a device array using a specified model.

    This function sends the device array to a serverless prediction service, which uses
    a specified machine learning model to predict the outcome of the nanofabrication
    process.

    Parameters
    ----------
    device_array : np.ndarray
        A 2D array representing the planar geometry of the device. Singleton
        dimensions are squeezed out before the array is encoded for transmission.
    model : Model
        The model to use for prediction, representing a specific fabrication process and
        dataset. This model encapsulates details about the fabrication foundry and
        process, as defined in `models.py`. Each model is associated with a version and
        dataset that detail its creation and the data it was trained on, ensuring the
        prediction is tailored to specific fabrication parameters.
    model_type : str
        The type of model to use (e.g., 'p' for prediction or 'c' for correction).
    binarize : bool
        If True, the predicted device geometry will be binarized (client-side, via
        `binarize_hard`). This is useful for converting probabilistic predictions
        into binary geometries.

    Returns
    -------
    np.ndarray
        The predicted output array. Each ``result*`` entry of the server response is
        decoded and the entries are stacked (in sorted key order) along a new last
        axis. For single-level predictions this gives shape (h, w, 1); for
        multi-level predictions, shape (h, w, n) where n is the number of levels.

    Raises
    ------
    RuntimeError
        If the request to the prediction service fails (connection problem or an
        HTTP error status).
    ValueError
        If the server returns an empty body, a body that is not valid JSON, an
        explicit error, or a payload without any ``result*`` arrays.
    """
    endpoint_url = f"{BASE_ENDPOINT_URL}-v{ENDPOINT_VERSION}.modal.run"
    predict_data = {
        "device_array": _encode_array(np.squeeze(device_array)),
        "model": model.to_json(),
        "model_type": model_type,
    }
    headers = _prepare_headers()

    # Only the network round-trip belongs in this try: anything raised later is a
    # problem with the payload, not with the request.
    try:
        response = requests.post(
            url=endpoint_url, data=json.dumps(predict_data), headers=headers
        )
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        raise RuntimeError(f"Request failed: {e}") from e

    if not response.content:
        raise ValueError("Empty response received from server")

    # Catch ValueError here rather than json.JSONDecodeError. Recent versions of
    # requests raise requests.exceptions.JSONDecodeError, which inherits from
    # RequestException as well as from a ValueError subclass. A combined
    # try/except would therefore misreport malformed JSON as a failed request.
    try:
        response_data = response.json()
    except ValueError as e:
        raise ValueError(f"JSON decode error: {e}") from e

    if not isinstance(response_data, dict):
        raise ValueError("Invalid response received from server")

    if "error" in response_data:
        raise ValueError(f"Prediction error: {response_data['error']}")

    results = response_data.get("results")
    if not isinstance(results, dict):
        raise ValueError("Invalid response: 'results' missing from server response")

    result_arrays = [
        _decode_array(results[key])
        for key in sorted(results.keys())
        if key.startswith("result")
    ]
    if not result_arrays:
        # np.stack([]) would otherwise fail with an unrelated-looking message.
        raise ValueError("Invalid response: no result arrays in server response")

    prediction_array = np.stack(result_arrays, axis=-1)

    if binarize:
        prediction_array = binarize_hard(prediction_array)

    return prediction_array
22
118
 
23
119
 
24
120
  def _predict_poly(
25
- polygon_points: list,
121
+ polygon_points: list[Any],
26
122
  model: Model,
27
123
  model_type: str,
28
124
  eta: float = 0.5,
29
- ) -> list:
125
+ ) -> list[Any]:
30
126
  """
31
- Predict the nanofabrication outcome for a list of polygons.
127
+ Predict the nanofabrication outcome for a geometry (list of polygons).
32
128
 
33
- This function sends polygon data to the server, which uses a specified machine
34
- learning model to predict the outcome of the nanofabrication process.
129
+ This function sends the device array to a serverless prediction service, which uses
130
+ a specified machine learning model to predict the outcome of the nanofabrication
131
+ process.
35
132
 
36
133
  Parameters
37
134
  ----------
@@ -39,13 +136,12 @@ def _predict_poly(
39
136
  List of polygon points, where each polygon is a list of [x, y] coordinates.
40
137
  model : Model
41
138
  The model to use for prediction, representing a specific fabrication process and
42
- dataset. This model encapsulates details about the fabrication foundry, process,
43
- material, technology, thickness, and sidewall presence, as defined in
44
- `models.py`. Each model is associated with a version and dataset that detail its
45
- creation and the data it was trained on, ensuring the prediction is tailored to
46
- specific fabrication parameters.
139
+ dataset. This model encapsulates details about the fabrication foundry and
140
+ process, as defined in `models.py`. Each model is associated with a version and
141
+ dataset that detail its creation and the data it was trained on, ensuring the
142
+ prediction is tailored to specific fabrication parameters.
47
143
  model_type : str
48
- The type of model to use ('p' for prediction, 'c' for correction).
144
+ The type of model to use (e.g., 'p' for prediction or 'c' for correction).
49
145
  eta : float
50
146
  The threshold value for binarization. Defaults to 0.5. Because intermediate
51
147
  values cannot be preserved in the polygon data, the predicted polygons are
@@ -54,17 +150,15 @@ def _predict_poly(
54
150
  Returns
55
151
  -------
56
152
  list
57
- List of predicted polygon points with channel information. Each polygon is a
58
- dict with 'points' (list of coordinates) and 'channel' (int) keys.
153
+ List of predicted polygon points with level information. Each polygon is a dict
154
+ with 'points' (list of coordinates) and 'level' (int) keys.
59
155
 
60
156
  Raises
61
157
  ------
62
- ValueError
63
- If the server returns an error or empty response.
64
- requests.exceptions.RequestException
158
+ RuntimeError
65
159
  If the request to the prediction service fails.
66
- json.JSONDecodeError
67
- If the response cannot be parsed as JSON.
160
+ ValueError
161
+ If the server returns an error or invalid response.
68
162
  """
69
163
  predict_data = {
70
164
  "polygons": polygon_points,
@@ -99,11 +193,89 @@ def _predict_poly(
99
193
  return []
100
194
 
101
195
  except requests.exceptions.RequestException as e:
102
- print(f"Request failed: {str(e)}")
103
- raise
196
+ raise RuntimeError(f"Request failed: {e}") from e
104
197
  except json.JSONDecodeError as e:
105
- print(f"JSON decode error: {str(e)}")
106
- raise
198
+ raise ValueError(f"JSON decode error: {e}") from e
199
+
200
+
201
def predict_gds(
    gds_path: str,
    cell_name: str,
    model: Model,
    model_type: str,
    gds_layer: tuple[int, int] = (1, 0),
    eta: float = 0.5,
    output_path: str | None = None,
) -> None:
    """
    Predict the nanofabrication outcome for a GDS file and cell.

    This function loads a GDS file, extracts the specified cell, and predicts the
    nanofabrication outcome using the specified model. The predicted cell is added
    to the original GDS library under a name that does not clash with any existing
    cell (a numeric suffix ``_1``, ``_2``, ... is appended if needed), and the library
    is written to the specified output path (or overwrites the original file if no
    output path is provided).

    Parameters
    ----------
    gds_path : str
        The file path to the GDS file.
    cell_name : str
        The name of the cell within the GDS file to predict.
    model : Model
        The model to use for prediction, representing a specific fabrication process and
        dataset. This model encapsulates details about the fabrication foundry and
        process, as defined in `models.py`. Each model is associated with a version and
        dataset that detail its creation and the data it was trained on, ensuring the
        prediction is tailored to specific fabrication parameters.
    model_type : str
        The type of model to use (e.g., 'p' for prediction or 'c' for correction).
    gds_layer : tuple[int, int]
        The layer and datatype to use within the GDS file. Defaults to (1, 0).
    eta : float
        The threshold value for binarization. Defaults to 0.5. Because intermediate
        values cannot be preserved in the polygon data, the predicted polygons are
        binarized using a threshold value of eta.
    output_path : str, optional
        The file path where the updated GDS file will be written. If None, the
        original file will be overwritten. Defaults to None.

    Raises
    ------
    RuntimeError
        If the request to the prediction service fails.
    ValueError
        If the specified cell is not found, no polygons exist on the requested
        layer, or the server returns an error or invalid response. Errors raised
        while reading the file propagate unchanged from `gdstk.read_gds`.
    """
    gdstk_library = gdstk.read_gds(gds_path)
    # RawCell entries carry no geometry we can predict on, so only gdstk.Cell
    # instances are eligible; the first name match wins.
    gdstk_cell = next(
        (
            cell
            for cell in gdstk_library.cells
            if isinstance(cell, gdstk.Cell) and cell.name == cell_name
        ),
        None,
    )
    if gdstk_cell is None:
        raise ValueError(f"Cell '{cell_name}' not found in GDS file")

    predicted_cell = predict_gdstk(
        gdstk_cell=gdstk_cell,
        model=model,
        model_type=model_type,
        gds_layer=gds_layer,
        eta=eta,
    )

    # Build the set of taken names once instead of re-listing every cell on each
    # probe; the library is not modified until after the loop.
    existing_names = {cell.name for cell in gdstk_library.cells}
    base_name = predicted_cell.name
    counter = 1
    while predicted_cell.name in existing_names:
        predicted_cell.name = f"{base_name}_{counter}"
        counter += 1

    gdstk_library.add(predicted_cell)

    write_path = output_path if output_path is not None else gds_path
    print(f"Writing to {write_path}")
    # gdstk fractures polygons with more vertices than max_points. 8190 keeps each
    # boundary within the GDSII per-record vertex limit while avoiding gdstk's
    # much smaller default.
    gdstk_library.write_gds(write_path, max_points=8190)
107
279
 
108
280
 
109
281
  def predict_gdstk(
@@ -116,8 +288,8 @@ def predict_gdstk(
116
288
  """
117
289
  Predict the nanofabrication outcome of a gdstk cell using a specified model.
118
290
 
119
- This function extracts polygons from a gdstk cell, sends them to the prediction
120
- server, and returns a new cell containing the predicted polygons.
291
+ This function extracts polygons from a gdstk cell, sends them to a serverless
292
+ prediction service, and returns a new cell containing the predicted polygons.
121
293
 
122
294
  Parameters
123
295
  ----------
@@ -125,13 +297,12 @@ def predict_gdstk(
125
297
  The gdstk.Cell object containing polygons to predict.
126
298
  model : Model
127
299
  The model to use for prediction, representing a specific fabrication process and
128
- dataset. This model encapsulates details about the fabrication foundry, process,
129
- material, technology, thickness, and sidewall presence, as defined in
130
- `models.py`. Each model is associated with a version and dataset that detail its
131
- creation and the data it was trained on, ensuring the prediction is tailored to
132
- specific fabrication parameters.
300
+ dataset. This model encapsulates details about the fabrication foundry and
301
+ process, as defined in `models.py`. Each model is associated with a version and
302
+ dataset that detail its creation and the data it was trained on, ensuring the
303
+ prediction is tailored to specific fabrication parameters.
133
304
  model_type : str
134
- The type of model to use ('p' for prediction, 'c' for correction).
305
+ The type of model to use (e.g., 'p' for prediction or 'c' for correction).
135
306
  gds_layer : tuple[int, int]
136
307
  The layer and datatype to use within the GDSTK cell. Defaults to (1, 0).
137
308
  eta : float
@@ -149,14 +320,17 @@ def predict_gdstk(
149
320
 
150
321
  Raises
151
322
  ------
323
+ RuntimeError
324
+ If the request to the prediction service fails.
152
325
  ValueError
153
- If no polygons are found in the specified layer.
326
+ If no polygons are found in the specified layer, or the server returns an error
327
+ or invalid response.
154
328
  """
155
329
  polygons = gdstk_cell.get_polygons(layer=gds_layer[0], datatype=gds_layer[1])
156
330
  if not polygons:
157
331
  raise ValueError("No polygons found in the specified layer")
158
332
 
159
- polygon_points = [polygon.points.tolist() for polygon in polygons]
333
+ polygon_points = [polygon.points.tolist() for polygon in polygons] # pyright: ignore[reportAttributeAccessIssue]
160
334
 
161
335
  predicted_polygon_data = _predict_poly(
162
336
  polygon_points=polygon_points,
@@ -165,7 +339,8 @@ def predict_gdstk(
165
339
  eta=eta,
166
340
  )
167
341
 
168
- result_cell = gdstk.Cell(f"{gdstk_cell.name}_predicted")
342
+ suffix = "corrected" if model_type == "c" else "predicted"
343
+ result_cell = gdstk.Cell(f"{gdstk_cell.name}_{suffix}")
169
344
 
170
345
  polygons_by_channel = {}
171
346
  for polygon_data in predicted_polygon_data:
@@ -183,129 +358,82 @@ def predict_gdstk(
183
358
 
184
359
  for points in points_list:
185
360
  points_array = np.array(points)
186
- polygon = gdstk.Polygon(points_array, layer=layer, datatype=datatype)
361
+ polygon = gdstk.Polygon(points_array, layer=layer, datatype=datatype) # pyright: ignore[reportArgumentType]
187
362
  result_cell.add(polygon)
188
363
 
189
364
  return result_cell
190
365
 
191
366
 
192
def _predict_array_with_grad(
    device_array: npt.NDArray[Any], model: Model
) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]:
    """
    Predict the nanofabrication outcome of a device array and compute its gradient.

    This function predicts the outcome of the nanofabrication process for a given
    device array using a specified model. It also computes the gradient of the
    prediction with respect to the input device array. The request always uses the
    prediction model type ('p') with server-side binarization disabled.

    Parameters
    ----------
    device_array : np.ndarray
        A 2D array representing the planar geometry of the device.
    model : Model
        The model to use for prediction, representing a specific fabrication process.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        The predicted output array and gradient array. The gradient is rescaled
        from its transmitted normalized form back to the range
        [gradient_min, gradient_max] reported by the server.

    Raises
    ------
    RuntimeError
        If the request to the prediction service fails (connection problem or an
        HTTP error status).
    ValueError
        If the server returns an empty body, a body that is not valid JSON, an
        explicit error, or a payload missing one of the expected fields.
    """
    headers = _prepare_headers()
    predict_data = {
        "device_array": _encode_array(np.squeeze(device_array)),
        "model": model.to_json(),
        "model_type": "p",
        "binary": False,
    }
    endpoint_url = f"{BASE_ENDPOINT_URL}-with-grad-v{ENDPOINT_VERSION}.modal.run"

    # Keep the try limited to the network round-trip so payload problems are not
    # reported as request failures.
    try:
        response = requests.post(
            endpoint_url, data=json.dumps(predict_data), headers=headers
        )
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        raise RuntimeError(f"Request failed: {e}") from e

    if not response.content:
        raise ValueError("Empty response received from server")

    # requests.exceptions.JSONDecodeError (requests >= 2.27) is both a
    # RequestException and a ValueError subclass. Catching ValueError in its own
    # try reports it as a decode problem, as documented.
    try:
        response_data = response.json()
    except ValueError as e:
        raise ValueError(f"JSON decode error: {e}") from e

    if not isinstance(response_data, dict):
        raise ValueError("Invalid response received from server")

    if "error" in response_data:
        raise ValueError(f"Prediction error: {response_data['error']}")

    required_fields = (
        "prediction_array",
        "gradient_array",
        "gradient_min",
        "gradient_max",
    )
    missing = [field for field in required_fields if field not in response_data]
    if missing:
        raise ValueError(
            f"Invalid response: missing field(s) {', '.join(missing)}"
        )

    prediction_array = _decode_array(response_data["prediction_array"])
    gradient_array = _decode_array(response_data["gradient_array"])
    gradient_min = response_data["gradient_min"]
    gradient_max = response_data["gradient_max"]
    # _decode_array divides pixel values by 255, so (presumably 8-bit) gradient
    # images arrive normalized; map them back onto the server-reported range.
    gradient_range = gradient_max - gradient_min
    gradient_array = gradient_array * gradient_range + gradient_min
    return (prediction_array, gradient_array)
305
431
 
306
432
 
307
433
  @primitive
308
- def predict_array_with_grad(device_array: np.ndarray, model: Model) -> np.ndarray:
434
+ def predict_array_with_grad(
435
+ device_array: npt.NDArray[Any], model: Model
436
+ ) -> npt.NDArray[Any]:
309
437
  """
310
438
  Predict the nanofabrication outcome of a device array and compute its gradient.
311
439
 
@@ -325,15 +453,24 @@ def predict_array_with_grad(device_array: np.ndarray, model: Model) -> np.ndarra
325
453
  -------
326
454
  np.ndarray
327
455
  The predicted output array.
456
+
457
+ Raises
458
+ ------
459
+ RuntimeError
460
+ If the request to the prediction service fails.
461
+ ValueError
462
+ If the server returns an error or invalid response.
328
463
  """
329
464
  prediction_array, gradient_array = _predict_array_with_grad(
330
465
  device_array=device_array, model=model
331
466
  )
332
- predict_array_with_grad.gradient_array = gradient_array # type: ignore
467
+ predict_array_with_grad.gradient_array = gradient_array # pyright: ignore[reportFunctionMemberAccess]
333
468
  return prediction_array
334
469
 
335
470
 
336
- def predict_array_with_grad_vjp(ans: np.ndarray, device_array: np.ndarray, *args):
471
+ def predict_array_with_grad_vjp(
472
+ ans: npt.NDArray[Any], device_array: npt.NDArray[Any], *args: Any
473
+ ) -> Any:
337
474
  """
338
475
  Define the vector-Jacobian product (VJP) for the prediction function.
339
476
 
@@ -351,10 +488,10 @@ def predict_array_with_grad_vjp(ans: np.ndarray, device_array: np.ndarray, *args
351
488
  function
352
489
  A function that computes the VJP given an upstream gradient `g`.
353
490
  """
354
- grad_x = predict_array_with_grad.gradient_array # type: ignore
491
+ grad_x = predict_array_with_grad.gradient_array # pyright: ignore[reportFunctionMemberAccess]
355
492
 
356
- def vjp(g: np.ndarray) -> np.ndarray:
357
- return g * grad_x
493
+ def vjp(g: npt.NDArray[Any]) -> npt.NDArray[Any]:
494
+ return g * grad_x # type: ignore[no-any-return]
358
495
 
359
496
  return vjp
360
497
 
@@ -362,7 +499,7 @@ def predict_array_with_grad_vjp(ans: np.ndarray, device_array: np.ndarray, *args
362
499
  defvjp(predict_array_with_grad, predict_array_with_grad_vjp)
363
500
 
364
501
 
365
- def _encode_array(array):
502
+ def _encode_array(array: npt.NDArray[Any]) -> str:
366
503
  """Encode an ndarray as a base64 encoded image for transmission."""
367
504
  image = Image.fromarray(np.uint8(array * 255))
368
505
  buffered = io.BytesIO()
@@ -371,15 +508,15 @@ def _encode_array(array):
371
508
  return encoded_png
372
509
 
373
510
 
374
def _decode_array(encoded_png: str) -> npt.NDArray[Any]:
    """
    Turn a base64-encoded image string back into an ndarray.

    The raw pixel values are divided by 255, so an 8-bit image comes back as floats
    in the range [0, 1].
    """
    png_stream = io.BytesIO(base64.b64decode(encoded_png))
    pixels = np.array(Image.open(png_stream))
    return pixels / 255  # type: ignore[no-any-return]
379
516
 
380
517
 
381
- def _read_tokens():
382
- """Read access and refresh tokens from the configuration file."""
518
+ def _prepare_headers() -> dict[str, str]:
519
+ """Prepare HTTP headers for a server request."""
383
520
  token_file_path = os.path.expanduser("~/.prefab.toml")
384
521
  try:
385
522
  with open(token_file_path) as file:
@@ -388,119 +525,14 @@ def _read_tokens():
388
525
  refresh_token = tokens.get("refresh_token")
389
526
  if not access_token or not refresh_token:
390
527
  raise ValueError("Tokens not found in the configuration file.")
391
- return access_token, refresh_token
528
+ return {
529
+ "Authorization": f"Bearer {access_token}",
530
+ "X-Refresh-Token": refresh_token,
531
+ }
392
532
  except FileNotFoundError:
393
533
  raise FileNotFoundError(
394
534
  "Could not validate user.\n"
395
- "Please update prefab using: pip install --upgrade prefab.\n"
396
- "Signup/login and generate a new token.\n"
397
- "See https://docs.prefabphotonics.com/."
535
+ + "Please update prefab using: pip install --upgrade prefab.\n"
536
+ + "Signup/login and generate a new token.\n"
537
+ + "See https://docs.prefabphotonics.com/."
398
538
  ) from None
399
-
400
-
401
- def _prepare_headers():
402
- """Prepare HTTP headers for the request."""
403
- access_token, refresh_token = _read_tokens()
404
- return {
405
- "Authorization": f"Bearer {access_token}",
406
- "X-Refresh-Token": refresh_token,
407
- }
408
-
409
-
410
- def _prepare_predict_data(device_array, model, model_type, binarize):
411
- """Prepare the data payload for the prediction request."""
412
- return {
413
- "device_array": _encode_array(np.squeeze(device_array)),
414
- "model": model.to_json(),
415
- "model_type": model_type,
416
- "binary": binarize,
417
- }
418
-
419
-
420
- def _process_response(response, model_type, binarize):
421
- """Process the streaming response from the prediction request."""
422
- event_type = None
423
- model_descriptions = {
424
- "p": "Prediction",
425
- "c": "Correction",
426
- "s": "SEMulate",
427
- }
428
- progress_bar = tqdm(
429
- total=100,
430
- desc=model_descriptions.get(model_type, "Processing"),
431
- unit="%",
432
- colour="green",
433
- bar_format="{l_bar}{bar:30}{r_bar}{bar:-10b}",
434
- )
435
-
436
- for line in response.iter_lines():
437
- if line:
438
- decoded_line = line.decode("utf-8").strip()
439
- if decoded_line.startswith("event:"):
440
- event_type = decoded_line.split(":", 1)[1].strip()
441
- elif decoded_line.startswith("data:"):
442
- data_content = _parse_data_line(decoded_line)
443
- result = _handle_event(event_type, data_content, progress_bar, binarize)
444
- if result is not None:
445
- progress_bar.close()
446
- return result
447
- progress_bar.close()
448
-
449
-
450
- def _parse_data_line(decoded_line):
451
- """Parse a data line from the response stream."""
452
- data_line = decoded_line.split(":", 1)[1].strip()
453
- try:
454
- return json.loads(data_line)
455
- except json.JSONDecodeError:
456
- raise ValueError(f"Failed to decode JSON: {data_line}") from None
457
-
458
-
459
- def _handle_event(event_type, data_content, progress_bar, binarize):
460
- """Handle different types of events received from the server."""
461
- if event_type == "progress":
462
- _update_progress(progress_bar, data_content)
463
- elif event_type == "result":
464
- return _process_result(data_content, binarize)
465
- elif event_type == "end":
466
- print("Stream ended.")
467
- elif event_type == "auth":
468
- _update_tokens(data_content.get("auth", {}))
469
- elif event_type == "error":
470
- raise ValueError(f"{data_content['error']}")
471
-
472
-
473
- def _update_progress(progress_bar, data_content):
474
- """Update the progress bar based on the progress event."""
475
- progress = round(100 * data_content.get("progress", 0))
476
- progress_bar.update(progress - progress_bar.n)
477
-
478
-
479
- def _process_result(data_content, binarize):
480
- """Process the result event and return the prediction."""
481
- results = [
482
- _decode_array(data_content[key])
483
- for key in sorted(data_content.keys())
484
- if key.startswith("result")
485
- ]
486
- if results:
487
- prediction = np.stack(results, axis=-1)
488
- if binarize:
489
- prediction = binarize_hard(prediction)
490
- return prediction
491
-
492
-
493
- def _update_tokens(auth_data):
494
- """Update tokens if new tokens are provided in the auth event."""
495
- new_access_token = auth_data.get("new_access_token")
496
- new_refresh_token = auth_data.get("new_refresh_token")
497
- if new_access_token and new_refresh_token:
498
- prefab_file_path = os.path.expanduser("~/.prefab.toml")
499
- with open(prefab_file_path, "w", encoding="utf-8") as toml_file:
500
- toml.dump(
501
- {
502
- "access_token": new_access_token,
503
- "refresh_token": new_refresh_token,
504
- },
505
- toml_file,
506
- )