podstack 1.3.17__tar.gz → 1.3.20__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {podstack-1.3.17 → podstack-1.3.20}/PKG-INFO +1 -1
  2. {podstack-1.3.17 → podstack-1.3.20}/podstack/annotations.py +6 -3
  3. {podstack-1.3.17 → podstack-1.3.20}/podstack/gpu_runner.py +152 -54
  4. {podstack-1.3.17 → podstack-1.3.20}/podstack/registry/__init__.py +28 -5
  5. {podstack-1.3.17 → podstack-1.3.20}/podstack/registry/client.py +30 -7
  6. {podstack-1.3.17 → podstack-1.3.20}/podstack.egg-info/PKG-INFO +1 -1
  7. {podstack-1.3.17 → podstack-1.3.20}/pyproject.toml +1 -1
  8. {podstack-1.3.17 → podstack-1.3.20}/LICENSE +0 -0
  9. {podstack-1.3.17 → podstack-1.3.20}/README.md +0 -0
  10. {podstack-1.3.17 → podstack-1.3.20}/podstack/__init__.py +0 -0
  11. {podstack-1.3.17 → podstack-1.3.20}/podstack/client.py +0 -0
  12. {podstack-1.3.17 → podstack-1.3.20}/podstack/exceptions.py +0 -0
  13. {podstack-1.3.17 → podstack-1.3.20}/podstack/execution.py +0 -0
  14. {podstack-1.3.17 → podstack-1.3.20}/podstack/models.py +0 -0
  15. {podstack-1.3.17 → podstack-1.3.20}/podstack/notebook.py +0 -0
  16. {podstack-1.3.17 → podstack-1.3.20}/podstack/registry/autolog.py +0 -0
  17. {podstack-1.3.17 → podstack-1.3.20}/podstack/registry/exceptions.py +0 -0
  18. {podstack-1.3.17 → podstack-1.3.20}/podstack/registry/experiment.py +0 -0
  19. {podstack-1.3.17 → podstack-1.3.20}/podstack/registry/model.py +0 -0
  20. {podstack-1.3.17 → podstack-1.3.20}/podstack/registry/model_utils.py +0 -0
  21. {podstack-1.3.17 → podstack-1.3.20}/podstack.egg-info/SOURCES.txt +0 -0
  22. {podstack-1.3.17 → podstack-1.3.20}/podstack.egg-info/dependency_links.txt +0 -0
  23. {podstack-1.3.17 → podstack-1.3.20}/podstack.egg-info/requires.txt +0 -0
  24. {podstack-1.3.17 → podstack-1.3.20}/podstack.egg-info/top_level.txt +0 -0
  25. {podstack-1.3.17 → podstack-1.3.20}/podstack_gpu/__init__.py +0 -0
  26. {podstack-1.3.17 → podstack-1.3.20}/podstack_gpu/app.py +0 -0
  27. {podstack-1.3.17 → podstack-1.3.20}/podstack_gpu/exceptions.py +0 -0
  28. {podstack-1.3.17 → podstack-1.3.20}/podstack_gpu/image.py +0 -0
  29. {podstack-1.3.17 → podstack-1.3.20}/podstack_gpu/runner.py +0 -0
  30. {podstack-1.3.17 → podstack-1.3.20}/podstack_gpu/secret.py +0 -0
  31. {podstack-1.3.17 → podstack-1.3.20}/podstack_gpu/utils.py +0 -0
  32. {podstack-1.3.17 → podstack-1.3.20}/podstack_gpu/volume.py +0 -0
  33. {podstack-1.3.17 → podstack-1.3.20}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: podstack
3
- Version: 1.3.17
3
+ Version: 1.3.20
4
4
  Summary: Official Python SDK for Podstack GPU Notebook Platform
5
5
  Author-email: Podstack <support@podstack.ai>
6
6
  License-Expression: MIT
@@ -90,6 +90,7 @@ class GPUConfig:
90
90
  conda: Union[str, list] = None,
91
91
  requirements: str = None,
92
92
  use_uv: bool = False,
93
+ env_vars: Optional[Dict[str, str]] = None,
93
94
  remote: bool = None
94
95
  ):
95
96
  """
@@ -121,6 +122,7 @@ class GPUConfig:
121
122
  self.conda = conda
122
123
  self.requirements = requirements
123
124
  self.use_uv = use_uv
125
+ self.env_vars = env_vars or {}
124
126
  self.remote = remote if remote is not None else _remote_execution_enabled
125
127
 
126
128
  # Store in global config
@@ -138,6 +140,7 @@ class GPUConfig:
138
140
  "conda": conda,
139
141
  "requirements": requirements,
140
142
  "use_uv": use_uv,
143
+ "env_vars": env_vars or {},
141
144
  }
142
145
 
143
146
  def __call__(self, func: Callable) -> Callable:
@@ -155,9 +158,6 @@ class GPUConfig:
155
158
  print(f"[Podstack] GPU Config (local): {self.type} x{self.count} @ {self.fraction}%")
156
159
  return func(*args, **kwargs)
157
160
 
158
- # Remote execution on GPU
159
- print(f"[Podstack] Provisioning GPU: {self.type} x{self.count} @ {self.fraction}%")
160
-
161
161
  try:
162
162
  runner = get_runner()
163
163
  except ValueError as e:
@@ -249,6 +249,7 @@ if __podstack_result__ is not None:
249
249
  conda=effective_conda,
250
250
  requirements=self.requirements,
251
251
  use_uv=self.use_uv,
252
+ env_vars=self.env_vars,
252
253
  runner=runner_name,
253
254
  wait=True,
254
255
  stream=None # Auto-detect: True in Jupyter, False otherwise
@@ -340,6 +341,7 @@ def gpu(
340
341
  conda: Union[str, list] = None,
341
342
  requirements: str = None,
342
343
  use_uv: bool = False,
344
+ env_vars: Optional[Dict[str, str]] = None,
343
345
  remote: bool = None
344
346
  ) -> GPUConfig:
345
347
  """
@@ -421,6 +423,7 @@ def gpu(
421
423
  conda=conda,
422
424
  requirements=requirements,
423
425
  use_uv=use_uv,
426
+ env_vars=env_vars,
424
427
  remote=remote
425
428
  )
426
429
 
@@ -18,6 +18,99 @@ import httpx
18
18
  # Configure logging
19
19
  logger = logging.getLogger("podstack.gpu_runner")
20
20
 
21
+ SPINNER_FRAMES = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
22
+
23
+
24
+ class LiveDisplay:
25
+ """Animated phase display for GPU runner lifecycle in Jupyter + terminal."""
26
+
27
+ PHASES = {
28
+ "pending": ("🔍", "Searching for GPU..."),
29
+ "queued": ("📋", "Queued — waiting for available GPU..."),
30
+ "provisioning": ("🚀", "Allocating GPU pod..."),
31
+ "running": None, # No spinner — logs stream directly
32
+ }
33
+ CHECKMARKS = {
34
+ "pending": "🔍 Submitted",
35
+ "queued": "📋 In queue",
36
+ "provisioning": "✓ GPU pod provisioning...",
37
+ "running": "✓ Pod ready — logging live output:",
38
+ }
39
+
40
+ def __init__(self, gpu_type: str, gpu_count: int, fraction: int):
41
+ self._gpu_type = gpu_type
42
+ self._gpu_count = gpu_count
43
+ self._fraction = fraction
44
+ self._phase = None
45
+ self._spinner_thread: Optional[threading.Thread] = None
46
+ self._stop_evt = threading.Event()
47
+ self._logs_started = False
48
+
49
+ def set_phase(self, status: str, extra: str = ""):
50
+ """Transition to a new lifecycle phase."""
51
+ self._stop_spinner()
52
+ if self._phase and self._phase != status:
53
+ label = self.CHECKMARKS.get(self._phase, f"✓ {self._phase}")
54
+ self._print(f"\r[Podstack] {label}{' ' * 30}\n")
55
+ self._phase = status
56
+ if status == "running":
57
+ if not self._logs_started:
58
+ self._logs_started = True
59
+ self._print(f"[Podstack] ─── Live Logs ({self._gpu_type} x{self._gpu_count}) ───\n\n")
60
+ else:
61
+ phase_info = self.PHASES.get(status)
62
+ if phase_info:
63
+ _, label = phase_info
64
+ if extra:
65
+ label = f"{label} {extra}"
66
+ self._start_spinner(f"[Podstack] {{spinner}} {label}")
67
+
68
+ def log(self, content: str, output_type: str = "stdout"):
69
+ """Write a log line from the pod (indented, real-time)."""
70
+ if output_type == "stderr":
71
+ sys.stderr.write(f" {content}" if not content.startswith(" ") else content)
72
+ sys.stderr.flush()
73
+ else:
74
+ sys.stdout.write(f" {content}" if not content.startswith(" ") else content)
75
+ sys.stdout.flush()
76
+
77
+ def complete(self, success: bool, gpu_seconds: float, cost_paise: int, error: str = None):
78
+ """Print final summary line."""
79
+ self._stop_spinner()
80
+ if self._logs_started:
81
+ self._print(f"\n[Podstack] ─────────────────────────────────────\n")
82
+ if success:
83
+ cost_str = f"₹{cost_paise/100:.2f}" if cost_paise else ""
84
+ self._print(f"[Podstack] ✓ Completed in {gpu_seconds:.1f}s | {self._gpu_type} x{self._gpu_count} | {cost_str}\n")
85
+ else:
86
+ self._print(f"[Podstack] ✗ Failed: {error}\n")
87
+
88
+ def _start_spinner(self, template: str):
89
+ self._stop_evt.clear()
90
+
91
+ def _spin():
92
+ i = 0
93
+ while not self._stop_evt.is_set():
94
+ frame = SPINNER_FRAMES[i % len(SPINNER_FRAMES)]
95
+ sys.stdout.write(f"\r{template.format(spinner=frame)} ")
96
+ sys.stdout.flush()
97
+ self._stop_evt.wait(0.1)
98
+ i += 1
99
+
100
+ self._spinner_thread = threading.Thread(target=_spin, daemon=True)
101
+ self._spinner_thread.start()
102
+
103
+ def _stop_spinner(self):
104
+ if self._spinner_thread and self._spinner_thread.is_alive():
105
+ self._stop_evt.set()
106
+ self._spinner_thread.join(timeout=0.5)
107
+ self._spinner_thread = None
108
+
109
+ @staticmethod
110
+ def _print(msg: str):
111
+ sys.stdout.write(msg)
112
+ sys.stdout.flush()
113
+
21
114
 
22
115
  def is_jupyter() -> bool:
23
116
  """Check if running in a Jupyter notebook."""
@@ -361,6 +454,7 @@ class GPURunner:
361
454
  conda: Union[str, list] = None,
362
455
  requirements: str = None,
363
456
  use_uv: bool = False,
457
+ env_vars: Dict[str, str] = None,
364
458
  add_annotation: bool = True,
365
459
  runner: str = None
366
460
  ) -> Dict[str, Any]:
@@ -395,6 +489,13 @@ class GPURunner:
395
489
  annotation = self._build_annotation(gpu, count, fraction, timeout, env, pip, uv, conda, requirements, use_uv, runner)
396
490
  code = f"{annotation}\n\n{code}"
397
491
 
492
+ # Inject environment variables
493
+ if env_vars:
494
+ env_lines = ["import os"]
495
+ for k, v in env_vars.items():
496
+ env_lines.append(f"os.environ[{repr(k)}] = {repr(str(v))}")
497
+ code = "\n".join(env_lines) + "\n\n" + code
498
+
398
499
  # Build installation code for packages
399
500
  install_parts = []
400
501
 
@@ -672,6 +773,7 @@ _stream_install(
672
773
  conda: Union[str, list] = None,
673
774
  requirements: str = None,
674
775
  use_uv: bool = False,
776
+ env_vars: Dict[str, str] = None,
675
777
  wait: bool = True,
676
778
  poll_interval: float = 2.0,
677
779
  max_retries: int = 3,
@@ -713,14 +815,12 @@ _stream_install(
713
815
  ValueError: If parameters are invalid
714
816
  """
715
817
  # Submit the code
716
- submission = self.submit(code, gpu, count, fraction, timeout, env, pip, uv, conda, requirements, use_uv, runner=runner)
818
+ submission = self.submit(code, gpu, count, fraction, timeout, env, pip, uv, conda, requirements, use_uv, env_vars=env_vars, runner=runner)
717
819
  execution_id = submission.get("execution_id")
718
820
 
719
821
  if not execution_id:
720
822
  raise RuntimeError(f"No execution_id in response: {submission}")
721
823
 
722
- print(f"[Podstack] Execution submitted: {execution_id}")
723
-
724
824
  if not wait:
725
825
  return GPUExecutionResult(
726
826
  execution_id=execution_id,
@@ -733,9 +833,9 @@ _stream_install(
733
833
  should_stream = stream if stream is not None else is_jupyter()
734
834
 
735
835
  if should_stream:
736
- return self._run_with_streaming(execution_id, gpu, count, timeout, max_retries, cancel_on_timeout)
836
+ return self._run_with_streaming(execution_id, gpu, count, timeout, max_retries, cancel_on_timeout, fraction)
737
837
  else:
738
- return self._run_with_polling(execution_id, gpu, count, timeout, poll_interval, max_retries, provisioning_timeout, cancel_on_timeout)
838
+ return self._run_with_polling(execution_id, gpu, count, timeout, poll_interval, max_retries, provisioning_timeout, cancel_on_timeout, fraction)
739
839
 
740
840
  def _run_with_streaming(
741
841
  self,
@@ -744,10 +844,12 @@ _stream_install(
744
844
  count: int,
745
845
  timeout: int,
746
846
  max_retries: int,
747
- cancel_on_timeout: bool
847
+ cancel_on_timeout: bool,
848
+ fraction: int = 100
748
849
  ) -> GPUExecutionResult:
749
850
  """Run execution with real-time output streaming."""
750
- print(f"[Podstack] Waiting for GPU runner ({gpu} x{count})...")
851
+ display = LiveDisplay(gpu, count, fraction)
852
+ display.set_phase("pending")
751
853
 
752
854
  start_time = time.time()
753
855
  output_buffer = []
@@ -755,13 +857,14 @@ _stream_install(
755
857
  final_event = {}
756
858
 
757
859
  try:
758
- for event in self.stream_output(execution_id, show_output=True):
860
+ for event in self.stream_output(execution_id, show_output=False):
759
861
  elapsed = time.time() - start_time
760
862
  if elapsed > timeout:
761
863
  if cancel_on_timeout:
762
864
  try:
763
865
  self.cancel(execution_id)
764
- print(f"\n[Podstack] Execution cancelled due to timeout")
866
+ display._stop_spinner()
867
+ display._print(f"\r[Podstack] Execution cancelled due to timeout{' ' * 30}\n")
765
868
  except Exception as e:
766
869
  logger.warning(f"Failed to cancel execution: {e}")
767
870
 
@@ -778,37 +881,41 @@ _stream_install(
778
881
  )
779
882
  )
780
883
 
781
- # Track output
782
- if event.get("type") in ("stdout", "stderr", "output"):
783
- content = event.get("content", "")
784
- if content:
785
- output_buffer.append(content)
786
-
787
- # Track status
884
+ # Track status transitions
788
885
  if "status" in event:
789
886
  new_status = event["status"]
790
887
  if new_status != final_status:
791
888
  final_status = new_status
792
- if final_status == "provisioning":
793
- print(f"\n[Podstack] Provisioning GPU runner...")
794
- elif final_status == "running":
795
- print(f"\n[Podstack] Running on GPU...")
889
+ extra = ""
890
+ if new_status == "queued":
891
+ pos = event.get("queue_position", "?")
892
+ extra = f"(position: {pos})"
893
+ display.set_phase(new_status, extra)
796
894
 
797
895
  # Check for terminal status
798
896
  if final_status in ("completed", "failed", "timeout", "cancelled"):
799
897
  final_event = event
800
898
  break
801
899
 
900
+ # Stream output lines
901
+ if event.get("type") in ("stdout", "stderr", "output"):
902
+ content = event.get("content", "")
903
+ if content:
904
+ output_buffer.append(content)
905
+ display.log(content, event.get("type", "stdout"))
906
+
802
907
  except RuntimeError as e:
803
908
  if "HTTP 401" in str(e):
804
909
  # Auth failed on stream — fall back to polling
805
- print(f"\n[Podstack] Streaming auth failed, falling back to polling...")
806
- return self._run_with_polling(execution_id, gpu, count, timeout, 2.0, max_retries, 300, cancel_on_timeout)
910
+ display._stop_spinner()
911
+ display._print(f"\r[Podstack] Streaming auth failed, falling back to polling...{' ' * 10}\n")
912
+ return self._run_with_polling(execution_id, gpu, count, timeout, 2.0, max_retries, 300, cancel_on_timeout, fraction)
807
913
  raise
808
914
  except (ConnectionError, httpx.ConnectError) as e:
809
915
  # Try to recover and get the result
810
916
  logger.warning(f"Stream connection lost: {e}")
811
- print(f"\n[Podstack] Stream connection lost, fetching final result...")
917
+ display._stop_spinner()
918
+ display._print(f"\r[Podstack] Stream connection lost, fetching final result...{' ' * 10}\n")
812
919
 
813
920
  # Get final result
814
921
  result = None
@@ -828,11 +935,7 @@ _stream_install(
828
935
  if "__PODSTACK_RESULT__" not in result.output and "__PODSTACK_RESULT__" in streamed:
829
936
  result.output = streamed
830
937
 
831
- if result.success:
832
- print(f"\n[Podstack] Completed in {result.gpu_seconds:.1f}s (cost: ₹{result.cost_paise/100:.2f})")
833
- else:
834
- error_msg = result.error or 'Unknown error'
835
- print(f"\n[Podstack] Failed: {error_msg}")
938
+ display.complete(result.success, result.gpu_seconds, result.cost_paise, result.error or "Unknown error")
836
939
 
837
940
  return result
838
941
 
@@ -845,10 +948,13 @@ _stream_install(
845
948
  poll_interval: float,
846
949
  max_retries: int,
847
950
  provisioning_timeout: int,
848
- cancel_on_timeout: bool
951
+ cancel_on_timeout: bool,
952
+ fraction: int = 100
849
953
  ) -> GPUExecutionResult:
850
954
  """Run execution with polling (non-streaming mode)."""
851
- print(f"[Podstack] Waiting for GPU runner ({gpu} x{count})...")
955
+ display = LiveDisplay(gpu, count, fraction)
956
+ display.set_phase("pending")
957
+
852
958
  start_time = time.time()
853
959
  provisioning_start = None
854
960
  last_status = ""
@@ -863,7 +969,8 @@ _stream_install(
863
969
  if cancel_on_timeout:
864
970
  try:
865
971
  self.cancel(execution_id)
866
- print(f"[Podstack] Execution cancelled due to timeout")
972
+ display._stop_spinner()
973
+ display._print(f"\r[Podstack] Execution cancelled due to timeout{' ' * 30}\n")
867
974
  except Exception as e:
868
975
  logger.warning(f"Failed to cancel execution: {e}")
869
976
 
@@ -905,24 +1012,17 @@ _stream_install(
905
1012
 
906
1013
  if current_status != last_status:
907
1014
  last_status = current_status
908
- if current_status == "pending":
909
- print(f"[Podstack] Pending...")
910
- elif current_status == "queued":
911
- pos = status_data.get("queue_position", "?")
912
- print(f"[Podstack] Queued (position: {pos})")
913
- elif current_status == "provisioning":
1015
+ if current_status == "provisioning":
914
1016
  provisioning_start = time.time()
915
- print(f"[Podstack] Provisioning GPU runner...")
916
- elif current_status == "running":
917
- print(f"[Podstack] Running on GPU...")
918
- elif current_status == "streaming":
919
- print(f"[Podstack] Streaming output...")
920
- elif current_status in ("completed", "failed", "timeout", "cancelled"):
921
- pass # Terminal states handled below
922
- else:
1017
+ if current_status in ("pending", "queued", "provisioning", "running"):
1018
+ extra = ""
1019
+ if current_status == "queued":
1020
+ pos = status_data.get("queue_position", "?")
1021
+ extra = f"(position: {pos})"
1022
+ display.set_phase(current_status, extra)
1023
+ elif current_status not in ("completed", "failed", "timeout", "cancelled"):
923
1024
  # Unknown status - log but continue
924
1025
  logger.warning(f"Unknown status: {current_status}")
925
- print(f"[Podstack] Status: {current_status}")
926
1026
 
927
1027
  # Check for provisioning timeout
928
1028
  if provisioning_start and current_status == "provisioning":
@@ -965,14 +1065,12 @@ _stream_install(
965
1065
  raise ConnectionError(f"Failed to get result after {max_retries} attempts: {e}")
966
1066
  time.sleep(poll_interval * (attempt + 1))
967
1067
 
968
- if result.success:
969
- print(f"[Podstack] Completed in {result.gpu_seconds:.1f}s (cost: ₹{result.cost_paise/100:.2f})")
970
- else:
971
- error_msg = result.error or 'Unknown error'
972
- print(f"[Podstack] Failed: {error_msg}")
973
- # Include partial output in the error for debugging
974
- if result.output:
975
- print(f"[Podstack] Output (last 500 chars):\n{result.output[-500:]}")
1068
+ display.complete(result.success, result.gpu_seconds, result.cost_paise, result.error or "Unknown error")
1069
+
1070
+ # In polling mode, output isn't streamed — print it after completion
1071
+ if result.output and not result.success:
1072
+ sys.stdout.write(f"[Podstack] Output (last 500 chars):\n{result.output[-500:]}\n")
1073
+ sys.stdout.flush()
976
1074
 
977
1075
  return result
978
1076
 
@@ -176,7 +176,13 @@ def list_experiments(limit: int = 20, offset: int = 0) -> list:
176
176
  return _get_client().list_experiments(limit, offset)
177
177
 
178
178
 
179
- def start_run(name: str = None, tags: dict = None) -> Run:
179
+ def start_run(
180
+ name: str = None,
181
+ tags: dict = None,
182
+ capture_env: bool = True,
183
+ system_metrics: bool = True,
184
+ system_metrics_interval: float = 10.0,
185
+ ) -> Run:
180
186
  """
181
187
  Start a new run in the active experiment.
182
188
 
@@ -188,11 +194,14 @@ def start_run(name: str = None, tags: dict = None) -> Run:
188
194
  Args:
189
195
  name: Optional run name
190
196
  tags: Optional tags dict
197
+ capture_env: Auto-capture Python/pip/git/CUDA as _env.* params
198
+ system_metrics: Log CPU/RAM/GPU metrics every system_metrics_interval seconds
199
+ system_metrics_interval: Seconds between metric samples (default 10)
191
200
 
192
201
  Returns:
193
202
  Run object (context manager)
194
203
  """
195
- return _get_client().start_run(name, tags)
204
+ return _get_client().start_run(name, tags, capture_env, system_metrics, system_metrics_interval)
196
205
 
197
206
 
198
207
  def end_run(status: str = "completed"):
@@ -558,9 +567,23 @@ def create_trial(sweep_id: str, run_id: str, params: dict) -> dict:
558
567
  return _get_client().create_trial(sweep_id, run_id, params)
559
568
 
560
569
 
561
- def complete_trial(sweep_id: str, trial_id: str, value: float) -> None:
562
- """Mark a trial as completed with its objective metric value."""
563
- _get_client().complete_trial(sweep_id, trial_id, value)
570
+ def complete_trial(
571
+ sweep_id: str,
572
+ trial_id: str,
573
+ value: float = None,
574
+ metrics: dict = None,
575
+ run_id: str = None,
576
+ ) -> None:
577
+ """Mark a trial as completed with its objective metric value.
578
+
579
+ Args:
580
+ sweep_id: Sweep ID.
581
+ trial_id: Trial ID (from suggest_trial_params).
582
+ value: Scalar objective value.
583
+ metrics: Dict of metric name → value; first entry used if value not set.
584
+ run_id: Optional run ID to link to this trial.
585
+ """
586
+ _get_client().complete_trial(sweep_id, trial_id, value, metrics, run_id)
564
587
 
565
588
 
566
589
  def list_trials(sweep_id: str) -> list:
@@ -1372,9 +1372,12 @@ class RegistryClient:
1372
1372
  return data.get("sweeps", [])
1373
1373
 
1374
1374
  def suggest_trial_params(self, sweep_id: str) -> Dict[str, Any]:
1375
- """Get suggested hyperparameter values for the next trial."""
1375
+ """Get suggested hyperparameter values for the next trial.
1376
+
1377
+ Returns a dict with ``id`` (trial ID) and ``params`` (hyperparameter map).
1378
+ """
1376
1379
  data = self._request("GET", f"/sweeps/{sweep_id}/suggest")
1377
- return data.get("params", data)
1380
+ return data.get("trial", data)
1378
1381
 
1379
1382
  def create_trial(
1380
1383
  self,
@@ -1389,11 +1392,31 @@ class RegistryClient:
1389
1392
  })
1390
1393
  return data.get("trial", data)
1391
1394
 
1392
- def complete_trial(self, sweep_id: str, trial_id: str, value: float) -> None:
1393
- """Mark a trial as completed with its objective metric value."""
1394
- self._request("POST", f"/sweeps/{sweep_id}/trials/{trial_id}/complete", json={
1395
- "value": value
1396
- })
1395
+ def complete_trial(
1396
+ self,
1397
+ sweep_id: str,
1398
+ trial_id: str,
1399
+ value: float = None,
1400
+ metrics: Dict[str, float] = None,
1401
+ run_id: str = None,
1402
+ ) -> None:
1403
+ """Mark a trial as completed with its objective metric value.
1404
+
1405
+ Args:
1406
+ sweep_id: Sweep ID.
1407
+ trial_id: Trial ID (from suggest_trial_params).
1408
+ value: Scalar objective value.
1409
+ metrics: Dict of metric name → value; first entry used if value not set.
1410
+ run_id: Optional run ID to link to this trial.
1411
+ """
1412
+ body: Dict[str, Any] = {}
1413
+ if value is not None:
1414
+ body["value"] = value
1415
+ if metrics:
1416
+ body["metrics"] = metrics
1417
+ if run_id:
1418
+ body["run_id"] = run_id
1419
+ self._request("POST", f"/sweeps/{sweep_id}/trials/{trial_id}/complete", json=body)
1397
1420
 
1398
1421
  def list_trials(self, sweep_id: str) -> List[Dict[str, Any]]:
1399
1422
  """List all trials for a sweep."""
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: podstack
3
- Version: 1.3.17
3
+ Version: 1.3.20
4
4
  Summary: Official Python SDK for Podstack GPU Notebook Platform
5
5
  Author-email: Podstack <support@podstack.ai>
6
6
  License-Expression: MIT
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "podstack"
7
- version = "1.3.17"
7
+ version = "1.3.20"
8
8
  description = "Official Python SDK for Podstack GPU Notebook Platform"
9
9
  readme = "README.md"
10
10
  license = "MIT"
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes