prefab 1.5.0__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff shows the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
prefab/__init__.py CHANGED
@@ -5,7 +5,7 @@ Usage:
5
5
  import prefab as pf
6
6
  """
7
7
 
8
- __version__ = "1.5.0"
8
+ __version__ = "1.6.0"
9
9
 
10
10
  from . import compare, geometry, predict, read, shapes
11
11
  from .device import BufferSpec, Device
prefab/predict.py CHANGED
@@ -9,6 +9,7 @@ differentiation.
9
9
  """
10
10
 
11
11
  import base64
12
+ import gzip
12
13
  import io
13
14
  import json
14
15
  import os
@@ -384,9 +385,13 @@ def _compute_vjp(
384
385
  ) -> npt.NDArray[Any]:
385
386
  """Compute J.T @ upstream_gradient via the server-side VJP endpoint."""
386
387
  headers = _prepare_headers()
388
+ upstream_arr = np.squeeze(upstream_gradient).astype(np.float32)
387
389
  vjp_data = {
388
390
  "device_array": _encode_array(np.squeeze(device_array)),
389
- "upstream_gradient": _encode_array(np.squeeze(upstream_gradient)),
391
+ "upstream_gradient": base64.b64encode(
392
+ gzip.compress(upstream_arr.tobytes(), compresslevel=1)
393
+ ).decode("utf-8"),
394
+ "upstream_gradient_shape": list(upstream_arr.shape),
390
395
  "model": model.to_json(),
391
396
  "model_type": "p",
392
397
  }
@@ -473,6 +478,8 @@ def _predict_array_diff_vjp(
473
478
  )
474
479
  # Clean up cache
475
480
  _diff_cache.pop(cache_key, None)
481
+ # Ensure gradient shape matches input shape
482
+ vjp_result = vjp_result.reshape(cached_device_array.shape)
476
483
  # Return gradient for device_array, None for model (not differentiable)
477
484
  return (vjp_result, None)
478
485
 
@@ -487,6 +494,69 @@ predict_array_with_grad = predict_array_diff
487
494
  """Alias for predict_array_diff. Deprecated, use predict_array_diff directly."""
488
495
 
489
496
 
497
+ def differentiable(model: Model):
498
+ """
499
+ Create a model-bound differentiable predictor for clean autograd integration.
500
+
501
+ Returns a function that takes only `device_array` as input, enabling seamless
502
+ composition with other differentiable functions. The VJP returns a single
503
+ gradient array (not a tuple), making it compatible with standard autograd workflows.
504
+
505
+ Parameters
506
+ ----------
507
+ model : Model
508
+ The model to use for prediction.
509
+
510
+ Returns
511
+ -------
512
+ callable
513
+ A differentiable prediction function that takes `device_array` and returns
514
+ the predicted fabrication outcome.
515
+
516
+ Examples
517
+ --------
518
+ >>> predictor = pf.predict.differentiable(model)
519
+ >>> def loss_fn(x):
520
+ ... pred = predictor(x)
521
+ ... return np.mean((pred - target) ** 2)
522
+ >>> gradient = grad(loss_fn)(device_array) # Returns array, not tuple
523
+ """
524
+
525
+ @primitive
526
+ def predict(device_array: npt.NDArray[Any]) -> npt.NDArray[Any]:
527
+ prediction = predict_array(
528
+ device_array=device_array,
529
+ model=model,
530
+ model_type="p",
531
+ binarize=False,
532
+ )
533
+ _diff_cache[id(prediction)] = (device_array.copy(), model)
534
+ return prediction
535
+
536
+ def predict_vjp(
537
+ ans: npt.NDArray[Any], device_array: npt.NDArray[Any]
538
+ ) -> Any:
539
+ cache_key = id(ans)
540
+ cached_device_array, cached_model = _diff_cache.get(
541
+ cache_key, (device_array, model)
542
+ )
543
+
544
+ def vjp(g: npt.NDArray[Any]) -> npt.NDArray[Any]:
545
+ vjp_result = _compute_vjp(
546
+ device_array=cached_device_array,
547
+ upstream_gradient=g,
548
+ model=cached_model,
549
+ )
550
+ _diff_cache.pop(cache_key, None)
551
+ # Ensure gradient shape matches input shape
552
+ return vjp_result.reshape(device_array.shape)
553
+
554
+ return vjp
555
+
556
+ defvjp(predict, predict_vjp)
557
+ return predict
558
+
559
+
490
560
  def _encode_array(array: npt.NDArray[Any]) -> str:
491
561
  """Encode an ndarray as a base64 encoded image for transmission."""
492
562
  image = Image.fromarray(np.uint8(array * 255))
@@ -500,7 +570,7 @@ def _decode_array(encoded_png: str) -> npt.NDArray[Any]:
500
570
  """Decode a base64 encoded image and return an ndarray."""
501
571
  binary_data = base64.b64decode(encoded_png)
502
572
  image = Image.open(io.BytesIO(binary_data))
503
- return np.array(image) / 255 # type: ignore[no-any-return]
573
+ return np.array(image) / 255
504
574
 
505
575
 
506
576
  def _prepare_headers() -> dict[str, str]:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: prefab
3
- Version: 1.5.0
3
+ Version: 1.6.0
4
4
  Summary: Artificial nanofabrication of integrated photonic circuits using deep learning
5
5
  Project-URL: Homepage, https://prefabphotonics.com
6
6
  Project-URL: Repository, https://github.com/PreFab-Photonics/PreFab
@@ -1,17 +1,17 @@
1
- prefab/__init__.py,sha256=shfnU_-iDY4MavJaXVb-Xzr9Y7w2_9zhShi4INgEPg4,425
1
+ prefab/__init__.py,sha256=WmddNmU7gmyH4cqkT6BW4iQ3RDG_l5SpKxg4poS6wsU,425
2
2
  prefab/__main__.py,sha256=uL6AdCeimPbXiWv0gnq9TDeZhAIrxGwJsEKNYbg9MZg,3454
3
3
  prefab/compare.py,sha256=aX7nr9tznSebYeeztvqIPz57npnJ4-iUeKEedrZdksE,3676
4
4
  prefab/device.py,sha256=1O6vTOq4wQRGVYvFWLH0uj1XhhYCfnDnIapDEYnBKHw,47996
5
5
  prefab/geometry.py,sha256=-nTaGjdw3KN1SVoyvqdcrE2GJP7OqPF6ivUhrO78rUk,11244
6
- prefab/predict.py,sha256=xEpzyS8uRMZFohBlTGmL90LmVsIn_41ZN7RRhwsw7Ww,18413
6
+ prefab/predict.py,sha256=gxxHCcaMv2TE5O639mQzO3IzFRsLGjY-FQowVp5vLkI,20659
7
7
  prefab/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
8
  prefab/read.py,sha256=5BmvFemobA72urxs4j2VZRVvanZZGu1mDB1Uh-epyvI,8635
9
9
  prefab/shapes.py,sha256=mRGwsPS-A9XsW3jgvUuMCEeNv9BLXsEnJkytlIHUAKE,28802
10
10
  prefab/models/__init__.py,sha256=rRrjcOcHPcob98Coksc0tbvYcXbm6SoLEb-Md233Jvo,1391
11
11
  prefab/models/base.py,sha256=t4VNMsOztPedj3kN5fZ1-4tk0SRHWrMuqnIVHztsCs4,1514
12
12
  prefab/models/evaluation.py,sha256=2_Klui6tY8xPvOSVD8VpZCVAnT1RX15FONqWG-_x-J8,484
13
- prefab-1.5.0.dist-info/METADATA,sha256=f8778_o8nPf4DIewUWqsK_2VfImxJish5I390Y9mtww,33754
14
- prefab-1.5.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
15
- prefab-1.5.0.dist-info/entry_points.txt,sha256=h1_A9O9F3NAIoKXD1RPb3Eo-WCSiHhMB_AnagBi6XTQ,48
16
- prefab-1.5.0.dist-info/licenses/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
17
- prefab-1.5.0.dist-info/RECORD,,
13
+ prefab-1.6.0.dist-info/METADATA,sha256=CvuP1Khblh_LA-vx5NLgTH5A_zaewTkklkaPEsIY4hI,33754
14
+ prefab-1.6.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
15
+ prefab-1.6.0.dist-info/entry_points.txt,sha256=h1_A9O9F3NAIoKXD1RPb3Eo-WCSiHhMB_AnagBi6XTQ,48
16
+ prefab-1.6.0.dist-info/licenses/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
17
+ prefab-1.6.0.dist-info/RECORD,,
File without changes