langfun 0.1.2.dev202509240805__py3-none-any.whl → 0.1.2.dev202509260805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langfun might be problematic. Click here for more details.

@@ -29,6 +29,7 @@ import time
29
29
  from typing import Annotated, Any, Callable, Iterator, Sequence, Type
30
30
 
31
31
  from langfun.env import interface
32
+ from langfun.env.event_handlers import base as event_handler_base
32
33
  import pyglove as pg
33
34
 
34
35
 
@@ -36,7 +37,7 @@ class BaseSandbox(interface.Sandbox):
36
37
  """Base class for a sandbox."""
37
38
 
38
39
  id: Annotated[
39
- interface.SandboxId,
40
+ interface.Sandbox.Id,
40
41
  'The identifier for the sandbox.'
41
42
  ]
42
43
 
@@ -92,10 +93,12 @@ class BaseSandbox(interface.Sandbox):
92
93
  assert self._status != status, (self._status, status)
93
94
  self.on_status_change(self._status, status)
94
95
  self._status = status
96
+ self._status_start_time = time.time()
95
97
 
96
- def _maybe_report_state_error(self, e: BaseException | None) -> None:
98
+ def report_maybe_state_error(self, e: BaseException | None) -> None:
97
99
  """Reports sandbox state errors."""
98
- if isinstance(e, interface.SandboxStateError):
100
+ if (isinstance(e, interface.SandboxStateError)
101
+ and e not in self._state_errors):
99
102
  self._state_errors.append(e)
100
103
 
101
104
  def _setup_features(self) -> None:
@@ -133,7 +136,7 @@ class BaseSandbox(interface.Sandbox):
133
136
  try:
134
137
  feature.teardown()
135
138
  except BaseException as e: # pylint: disable=broad-except
136
- self._maybe_report_state_error(e)
139
+ self.report_maybe_state_error(e)
137
140
  errors[feature.name] = e
138
141
  if errors:
139
142
  return interface.FeatureTeardownError(sandbox=self, errors=errors)
@@ -167,7 +170,7 @@ class BaseSandbox(interface.Sandbox):
167
170
  try:
168
171
  feature.teardown_session()
169
172
  except BaseException as e: # pylint: disable=broad-except
170
- self._maybe_report_state_error(e)
173
+ self.report_maybe_state_error(e)
171
174
  feature_teardown_errors[name] = e
172
175
 
173
176
  return interface.SessionTeardownError(
@@ -190,7 +193,6 @@ class BaseSandbox(interface.Sandbox):
190
193
  for name, feature in self.environment.features.items()
191
194
  })
192
195
  self._event_handlers = []
193
-
194
196
  self._enable_pre_session_setup = (
195
197
  self.reusable and self.proactive_session_setup
196
198
  )
@@ -202,18 +204,24 @@ class BaseSandbox(interface.Sandbox):
202
204
  )
203
205
  )
204
206
  self._housekeep_thread = None
205
- self._housekeep_count = 0
207
+ self._housekeep_counter = 0
206
208
 
207
209
  # Runtime state.
208
210
  self._status = self.Status.CREATED
211
+ self._status_start_time = time.time()
212
+
209
213
  self._start_time = None
210
214
  self._state_errors = []
215
+
211
216
  self._features_with_setup_called = set()
212
217
  self._features_with_setup_session_called = set()
213
218
 
214
219
  self._session_id = None
215
220
  self._session_start_time = None
216
221
 
222
+ # Thread local state for this sandbox.
223
+ self._tls_state = threading.local()
224
+
217
225
  @functools.cached_property
218
226
  def working_dir(self) -> str | None:
219
227
  """Returns the working directory for the sandbox."""
@@ -228,16 +236,21 @@ class BaseSandbox(interface.Sandbox):
228
236
  """Marks the sandbox as acquired."""
229
237
  self._set_status(self.Status.ACQUIRED)
230
238
 
239
+ @property
240
+ def housekeep_counter(self) -> int:
241
+ """Returns the housekeeping counter."""
242
+ return self._housekeep_counter
243
+
231
244
  def add_event_handler(
232
245
  self,
233
- event_handler: interface.EnvironmentEventHandler | None
246
+ event_handler: event_handler_base.EventHandler | None
234
247
  ) -> None:
235
248
  """Sets the event handler for the sandbox."""
236
249
  self._event_handlers.append(event_handler)
237
250
 
238
251
  def remove_event_handler(
239
252
  self,
240
- event_handler: interface.EnvironmentEventHandler | None
253
+ event_handler: event_handler_base.EventHandler | None
241
254
  ) -> None:
242
255
  """Removes the event handler for the sandbox."""
243
256
  self._event_handlers.remove(event_handler)
@@ -247,11 +260,40 @@ class BaseSandbox(interface.Sandbox):
247
260
  """Returns all errors encountered during sandbox lifecycle."""
248
261
  return self._state_errors
249
262
 
263
+ @property
264
+ def is_shutting_down(self) -> bool:
265
+ """Returns True if the sandbox is shutting down."""
266
+ return self._status == self.Status.SHUTTING_DOWN or (
267
+ self._state_errors and self._status == self.Status.EXITING_SESSION
268
+ )
269
+
250
270
  @property
251
271
  def features(self) -> dict[str, interface.Feature]:
252
272
  """Returns the features in the sandbox."""
253
273
  return self._features
254
274
 
275
+ def _enter_service_call(self) -> bool:
276
+ """Enters a service call.
277
+
278
+ Returns:
279
+ True if the service call is at the top of the call stack.
280
+ """
281
+ v = getattr(self._tls_state, 'service_call_depth', None)
282
+ if v is None:
283
+ v = 0
284
+ setattr(self._tls_state, 'service_call_depth', v + 1)
285
+ return v == 0
286
+
287
+ def _exit_service_call(self) -> bool:
288
+ """Exits a service call.
289
+
290
+ Returns:
291
+ True if the service call is at the top of the call stack.
292
+ """
293
+ v = getattr(self._tls_state, 'service_call_depth')
294
+ setattr(self._tls_state, 'service_call_depth', v - 1)
295
+ return v == 1
296
+
255
297
  #
256
298
  # Sandbox start/shutdown.
257
299
  #
@@ -289,7 +331,7 @@ class BaseSandbox(interface.Sandbox):
289
331
  f'it is in {self._status} status.'
290
332
  )
291
333
 
292
- t = time.time()
334
+ starting_time = time.time()
293
335
  self._state = self.Status.SETTING_UP
294
336
 
295
337
  try:
@@ -314,18 +356,16 @@ class BaseSandbox(interface.Sandbox):
314
356
  # Mark the sandbox as ready when all setup succeeds.
315
357
  self._set_status(self.Status.READY)
316
358
 
317
- self.on_start()
318
- pg.logging.info(
319
- '[%s]: Sandbox started in %.2f seconds.',
320
- self.id, time.time() - t
321
- )
359
+ duration = time.time() - starting_time
360
+ self.on_start(duration)
322
361
  except BaseException as e: # pylint: disable=broad-except
362
+ duration = time.time() - starting_time
323
363
  pg.logging.error(
324
- '[%s]: Sandbox failed to start: %s',
325
- self.id, e
364
+ '[%s]: Sandbox failed to start in %.2f seconds: %s',
365
+ self.id, duration, e
326
366
  )
327
- self._maybe_report_state_error(e)
328
- self.on_start(e)
367
+ self.report_maybe_state_error(e)
368
+ self.on_start(duration, e)
329
369
  self.shutdown()
330
370
  raise e
331
371
 
@@ -379,7 +419,6 @@ class BaseSandbox(interface.Sandbox):
379
419
  return
380
420
 
381
421
  self._set_status(interface.Sandbox.Status.SHUTTING_DOWN)
382
- shutdown_start_time = time.time()
383
422
 
384
423
  if (self._housekeep_thread is not None
385
424
  and threading.current_thread() is not self._housekeep_thread):
@@ -390,20 +429,11 @@ class BaseSandbox(interface.Sandbox):
390
429
  try:
391
430
  self._shutdown()
392
431
  self._set_status(interface.Sandbox.Status.OFFLINE)
393
-
394
- pg.logging.info(
395
- '[%s]: Sandbox shutdown in %.2f seconds. '
396
- '(lifetime: %.2f seconds, teardown errors: %s)',
397
- self.id,
398
- time.time() - shutdown_start_time,
399
- time.time() - self._start_time if self._start_time else 0,
400
- teardown_error
401
- )
402
432
  self.on_shutdown(teardown_error)
403
433
  shutdown_error = None
404
434
  except BaseException as e: # pylint: disable=broad-except
405
435
  shutdown_error = e
406
- self._maybe_report_state_error(e)
436
+ self.report_maybe_state_error(e)
407
437
  self._set_status(interface.Sandbox.Status.OFFLINE)
408
438
  pg.logging.error(
409
439
  '[%s]: Sandbox shutdown with error: %s',
@@ -496,10 +526,12 @@ class BaseSandbox(interface.Sandbox):
496
526
  try:
497
527
  self._start_session()
498
528
  self._set_status(self.Status.IN_SESSION)
499
- self.on_session_start(session_id)
529
+ self.on_session_start(session_id, time.time() - self._session_start_time)
500
530
  except BaseException as e: # pylint: disable=broad-except
501
- self._maybe_report_state_error(e)
502
- self.on_session_start(session_id, e)
531
+ self.report_maybe_state_error(e)
532
+ self.on_session_start(
533
+ session_id, time.time() - self._session_start_time, e
534
+ )
503
535
  self.shutdown()
504
536
  raise e
505
537
 
@@ -507,15 +539,19 @@ class BaseSandbox(interface.Sandbox):
507
539
  """Ends the user session with the sandbox.
508
540
 
509
541
  State transitions:
510
- IN_SESSION -> READY: When user session exits normally, and sandbox is set
511
- to reuse.
512
- IN_SESSION -> SHUTTING_DOWN -> OFFLINE: When user session exits while
542
+ IN_SESSION -> EXITING_SESSION -> READY: When user session exits normally,
543
+ and sandbox is set to reuse.
544
+ IN_SESSION -> EXITING_SESSION -> SHUTTING_DOWN -> OFFLINE: When user
545
+ session exits while
513
546
  sandbox is set not to reuse, or session teardown fails.
514
- IN_SESSION -> SETTING_UP -> READY: When user session exits normally, and
515
- sandbox is set to reuse, and proactive session setup is enabled.
516
- IN_SESSION -> SETTING_UP -> SHUTTING_DOWN -> OFFLINE: When user session
517
- exits normally, and proactive session setup is enabled but fails.
518
- not (IN_SESSION) -> same state: No operation
547
+ IN_SESSION -> EXITING_SESSION -> SETTING_UP -> READY: When user session
548
+ exits normally, and sandbox is set to reuse, and proactive session setup
549
+ is enabled.
550
+ IN_SESSION -> EXITING_SESSION -> SETTING_UP -> SHUTTING_DOWN -> OFFLINE:
551
+ When user session exits normally, and proactive session setup is enabled
552
+ but fails.
553
+ EXITING_SESSION -> EXITING_SESSION: No operation.
554
+ not IN_SESSION -> same state: No operation
519
555
 
520
556
  `end_session` should always be called for each `start_session` call, even
521
557
  when the session fails to start, to ensure proper cleanup.
@@ -541,6 +577,9 @@ class BaseSandbox(interface.Sandbox):
541
577
  Raises:
542
578
  BaseException: If session teardown failed with user-defined errors.
543
579
  """
580
+ if self._status == self.Status.EXITING_SESSION:
581
+ return
582
+
544
583
  if self._status not in (
545
584
  self.Status.IN_SESSION,
546
585
  ):
@@ -549,6 +588,8 @@ class BaseSandbox(interface.Sandbox):
549
588
  assert self._session_id is not None, (
550
589
  'No user session is active for this sandbox'
551
590
  )
591
+ # Set sandbox status to EXITING_SESSION to avoid re-entry.
592
+ self._set_status(self.Status.EXITING_SESSION)
552
593
  shutdown_sandbox = shutdown_sandbox or not self.reusable
553
594
 
554
595
  # Teardown features for the current session.
@@ -572,7 +613,7 @@ class BaseSandbox(interface.Sandbox):
572
613
  self.id,
573
614
  e
574
615
  )
575
- self._maybe_report_state_error(e)
616
+ self.report_maybe_state_error(e)
576
617
  self.shutdown()
577
618
 
578
619
  # End session before setting up the next session.
@@ -602,14 +643,6 @@ class BaseSandbox(interface.Sandbox):
602
643
  self._set_status(interface.Sandbox.Status.ACQUIRED)
603
644
  shutdown_sandbox = True
604
645
 
605
- pg.logging.info(
606
- '[%s]: User session %s ended. '
607
- '(lifetime: %.2f seconds, teardown errors: %s).',
608
- self.id,
609
- self._session_id,
610
- time.time() - self._session_start_time,
611
- end_session_error
612
- )
613
646
  self._session_start_time = None
614
647
  self._session_event_handler = None
615
648
 
@@ -633,8 +666,31 @@ class BaseSandbox(interface.Sandbox):
633
666
  last_ping = now
634
667
  last_housekeep_time = {name: now for name in self._features.keys()}
635
668
 
669
+ def _next_housekeep_wait_time() -> float:
670
+ # Decide how long to sleep for the next housekeeping.
671
+ next_housekeep_time = None
672
+ if self.keepalive_interval is not None:
673
+ next_housekeep_time = last_ping + self.keepalive_interval
674
+
675
+ for name, feature in self._features.items():
676
+ if feature.housekeep_interval is None:
677
+ continue
678
+ next_feature_housekeep_time = (
679
+ last_housekeep_time[name] + feature.housekeep_interval
680
+ )
681
+ if (next_housekeep_time is None
682
+ or next_housekeep_time > next_feature_housekeep_time):
683
+ next_housekeep_time = next_feature_housekeep_time
684
+
685
+ # Housekeep loop is installed when at least one feature requires
686
+ # housekeeping or the sandbox has a keepalive interval.
687
+ assert next_housekeep_time is not None
688
+ return max(0, next_housekeep_time - time.time())
689
+
636
690
  while self._status not in (self.Status.SHUTTING_DOWN, self.Status.OFFLINE):
691
+ housekeep_start = time.time()
637
692
  if self.keepalive_interval is not None:
693
+
638
694
  if time.time() - last_ping > self.keepalive_interval:
639
695
  try:
640
696
  self.ping()
@@ -645,8 +701,9 @@ class BaseSandbox(interface.Sandbox):
645
701
  self.id,
646
702
  str(e)
647
703
  )
648
- self._housekeep_count += 1
649
- self._maybe_report_state_error(e)
704
+ self._housekeep_counter += 1
705
+ self.report_maybe_state_error(e)
706
+ self.on_housekeep(time.time() - housekeep_start, e)
650
707
  self.shutdown()
651
708
  break
652
709
  last_ping = time.time()
@@ -667,20 +724,28 @@ class BaseSandbox(interface.Sandbox):
667
724
  feature.name,
668
725
  e,
669
726
  )
670
- self._maybe_report_state_error(e)
727
+ self.report_maybe_state_error(e)
728
+ self._housekeep_counter += 1
729
+ self.on_housekeep(time.time() - housekeep_start, e)
671
730
  self.shutdown()
672
731
  break
673
- self._housekeep_count += 1
674
- time.sleep(1)
732
+
733
+ self._housekeep_counter += 1
734
+ self.on_housekeep(time.time() - housekeep_start)
735
+ time.sleep(_next_housekeep_wait_time())
675
736
 
676
737
  #
677
738
  # Event handlers subclasses can override.
678
739
  #
679
740
 
680
- def on_start(self, error: BaseException | None = None) -> None:
741
+ def on_start(
742
+ self,
743
+ duration: float,
744
+ error: BaseException | None = None
745
+ ) -> None:
681
746
  """Called when the sandbox is started."""
682
747
  for handler in self._event_handlers:
683
- handler.on_sandbox_start(self.environment, self, error)
748
+ handler.on_sandbox_start(self.environment, self, duration, error)
684
749
 
685
750
  def on_status_change(
686
751
  self,
@@ -690,96 +755,124 @@ class BaseSandbox(interface.Sandbox):
690
755
  """Called when the sandbox status changes."""
691
756
  for handler in self._event_handlers:
692
757
  handler.on_sandbox_status_change(
693
- self.environment, self, old_status, new_status
758
+ self.environment,
759
+ self,
760
+ old_status,
761
+ new_status,
762
+ time.time() - self._status_start_time
694
763
  )
695
764
 
696
765
  def on_shutdown(self, error: BaseException | None = None) -> None:
697
766
  """Called when the sandbox is shutdown."""
767
+ if self._start_time is None:
768
+ lifetime = 0.0
769
+ else:
770
+ lifetime = time.time() - self._start_time
698
771
  for handler in self._event_handlers:
699
- handler.on_sandbox_shutdown(self.environment, self, error)
772
+ handler.on_sandbox_shutdown(self.environment, self, lifetime, error)
773
+
774
+ def on_housekeep(
775
+ self,
776
+ duration: float,
777
+ error: BaseException | None = None
778
+ ) -> None:
779
+ """Called when the sandbox finishes a round of housekeeping."""
780
+ counter = self._housekeep_counter
781
+ for handler in self._event_handlers:
782
+ handler.on_sandbox_housekeep(
783
+ self.environment, self, counter, duration, error
784
+ )
700
785
 
701
786
  def on_feature_setup(
702
787
  self,
703
788
  feature: interface.Feature,
789
+ duration: float,
704
790
  error: BaseException | None = None
705
791
  ) -> None:
706
792
  """Called when a feature is setup."""
707
793
  for handler in self._event_handlers:
708
794
  handler.on_feature_setup(
709
- self.environment, self, feature, error
795
+ self.environment, self, feature, duration, error
710
796
  )
711
797
 
712
798
  def on_feature_teardown(
713
799
  self,
714
800
  feature: interface.Feature,
801
+ duration: float,
715
802
  error: BaseException | None = None
716
803
  ) -> None:
717
804
  """Called when a feature is teardown."""
718
805
  for handler in self._event_handlers:
719
806
  handler.on_feature_teardown(
720
- self.environment, self, feature, error
807
+ self.environment, self, feature, duration, error
721
808
  )
722
809
 
723
810
  def on_feature_setup_session(
724
811
  self,
725
812
  feature: interface.Feature,
813
+ duration: float,
726
814
  error: BaseException | None = None
727
815
  ) -> None:
728
816
  """Called when a feature is setup for a user session."""
729
817
  for handler in self._event_handlers:
730
818
  handler.on_feature_setup_session(
731
- self.environment, self, feature, self.session_id, error
819
+ self.environment, self, feature, self.session_id, duration, error
732
820
  )
733
821
 
734
822
  def on_feature_teardown_session(
735
823
  self,
736
824
  feature: interface.Feature,
825
+ duration: float,
737
826
  error: BaseException | None = None
738
827
  ) -> None:
739
828
  """Called when a feature is teardown for a user session."""
740
829
  for handler in self._event_handlers:
741
830
  handler.on_feature_teardown_session(
742
- self.environment, self, feature, self.session_id, error
831
+ self.environment, self, feature, self.session_id, duration, error
743
832
  )
744
833
 
745
834
  def on_feature_housekeep(
746
835
  self,
747
836
  feature: interface.Feature,
837
+ counter: int,
838
+ duration: float,
748
839
  error: BaseException | None = None
749
840
  ) -> None:
750
841
  """Called when a feature is housekeeping."""
751
842
  for handler in self._event_handlers:
752
843
  handler.on_feature_housekeep(
753
- self.environment, self, feature, error
844
+ self.environment, self, feature, counter, duration, error
754
845
  )
755
846
 
756
847
  def on_session_start(
757
848
  self,
758
849
  session_id: str,
850
+ duration: float,
759
851
  error: BaseException | None = None
760
852
  ) -> None:
761
853
  """Called when the user session starts."""
762
854
  for handler in self._event_handlers:
763
855
  handler.on_session_start(
764
- self.environment, self, session_id, error
856
+ self.environment, self, session_id, duration, error
765
857
  )
766
858
 
767
- def on_session_activity(
859
+ def on_activity(
768
860
  self,
769
- session_id: str,
770
861
  name: str,
862
+ duration: float,
771
863
  feature: interface.Feature | None = None,
772
864
  error: BaseException | None = None,
773
865
  **kwargs
774
866
  ) -> None:
775
867
  """Called when a sandbox activity is performed."""
776
868
  for handler in self._event_handlers:
777
- handler.on_session_activity(
778
- session_id=session_id,
869
+ handler.on_sandbox_activity(
779
870
  name=name,
780
871
  environment=self.environment,
781
872
  sandbox=self,
782
873
  feature=feature,
874
+ session_id=self.session_id,
875
+ duration=duration,
783
876
  error=error,
784
877
  **kwargs
785
878
  )
@@ -790,9 +883,10 @@ class BaseSandbox(interface.Sandbox):
790
883
  error: BaseException | None = None
791
884
  ) -> None:
792
885
  """Called when the user session ends."""
886
+ lifetime = time.time() - self._session_start_time
793
887
  for handler in self._event_handlers:
794
888
  handler.on_session_end(
795
- self.environment, self, session_id, error
889
+ self.environment, self, session_id, lifetime, error
796
890
  )
797
891
 
798
892
 
@@ -876,9 +970,14 @@ def sandbox_service(
876
970
  @functools.wraps(func)
877
971
  def method_wrapper(self, *args, **kwargs) -> Any:
878
972
  """Helper function to safely execute logics in the sandbox."""
973
+
879
974
  assert isinstance(self, (BaseSandbox, interface.Feature)), self
880
975
  sandbox = self.sandbox if isinstance(self, interface.Feature) else self
881
976
 
977
+ # We count the service call stack depth so we could shutdown the sandbox
978
+ # at the top upon sandbox state error.
979
+ sandbox._enter_service_call() # pylint: disable=protected-access
980
+
882
981
  # When a capability is directly accessed from the environment,
883
982
  # we create a new session for the capability call. This
884
983
  # prevents the sandbox from being reused for other feature calls.
@@ -895,70 +994,102 @@ def sandbox_service(
895
994
  new_session = False
896
995
 
897
996
  kwargs.pop('session_id', None)
898
- session_id = sandbox.session_id
899
997
  result = None
900
- state_error = None
901
998
  error = None
999
+ start_time = time.time()
902
1000
 
903
1001
  try:
904
1002
  # Execute the service function.
905
1003
  result = func(self, *args, **kwargs)
906
1004
 
907
- # If the result is a context manager, use it and end the session
908
- # afterwards.
909
- if new_session and isinstance(
910
- result, contextlib.AbstractContextManager
911
- ):
912
- return _end_session_when_exit(result, sandbox)
1005
+ # If the result is a context manager, wrap it with a context manager
1006
+ # to end the session when exiting.
1007
+ if isinstance(result, contextlib.AbstractContextManager):
1008
+ return _service_context_manager_wrapper(
1009
+ service=result,
1010
+ sandbox_or_feature=self,
1011
+ sandbox=sandbox,
1012
+ name=func.__name__,
1013
+ kwargs=to_kwargs(*args, **kwargs),
1014
+ start_time=start_time,
1015
+ new_session=new_session
1016
+ )
913
1017
 
914
1018
  # Otherwise, return the result and end the session in the finally block.
915
1019
  return result
916
- except interface.SandboxStateError as e:
917
- sandbox._maybe_report_state_error(e) # pylint: disable=protected-access
918
- state_error = e
919
- error = e
920
- raise
921
1020
  except BaseException as e:
922
1021
  error = e
1022
+ sandbox.report_maybe_state_error(e)
923
1023
  if pg.match_error(e, critical_errors):
924
1024
  state_error = interface.SandboxStateError(
925
1025
  'Sandbox encountered an unexpected error executing '
926
1026
  f'`{func.__name__}` (args={args!r}, kwargs={kwargs!r}): {e}',
927
1027
  sandbox=self
928
1028
  )
929
- sandbox._maybe_report_state_error(state_error) # pylint: disable=protected-access
1029
+ sandbox.report_maybe_state_error(state_error)
930
1030
  raise state_error from e
931
1031
  raise
932
1032
  finally:
933
- if session_id is not None:
934
- self.on_session_activity(
1033
+ is_topmost_call = sandbox._exit_service_call() # pylint: disable=protected-access
1034
+ if not isinstance(result, contextlib.AbstractContextManager):
1035
+ self.on_activity(
935
1036
  name=func.__name__,
936
- session_id=session_id,
1037
+ duration=time.time() - start_time,
937
1038
  error=error,
938
1039
  **to_kwargs(*args, **kwargs),
939
1040
  )
940
-
941
- if state_error is not None:
1041
+ if new_session:
1042
+ assert is_topmost_call
1043
+
1044
+ # End the session if it's from a feature method and the result
1045
+ # is not a context manager.
1046
+ sandbox.end_session()
1047
+
1048
+ # Shutdown the sandbox if it is at the top of the service call stack and
1049
+ # has state errors.
1050
+ if (is_topmost_call
1051
+ and sandbox.state_errors
1052
+ # Sandbox service method might be called during shutting down, in
1053
+ # that case we don't want to shutdown the sandbox again.
1054
+ and not sandbox.is_shutting_down):
942
1055
  sandbox.shutdown()
943
- elif (new_session
944
- and not isinstance(result, contextlib.AbstractContextManager)):
945
- # End the session if it's from a feature method and the result is not
946
- # a context manager.
947
- sandbox.end_session(
948
- shutdown_sandbox=isinstance(error, interface.SandboxStateError)
949
- )
1056
+
950
1057
  return method_wrapper
951
1058
  return decorator
952
1059
 
953
1060
 
954
1061
  @contextlib.contextmanager
955
- def _end_session_when_exit(
1062
+ def _service_context_manager_wrapper(
956
1063
  service: contextlib.AbstractContextManager[Any],
957
- sandbox: interface.Sandbox
1064
+ sandbox_or_feature: BaseSandbox | interface.Feature,
1065
+ sandbox: interface.Sandbox,
1066
+ name: str,
1067
+ kwargs: dict[str, Any],
1068
+ new_session: bool,
1069
+ start_time: float,
958
1070
  ) -> Iterator[Any]:
959
1071
  """Context manager wrapper for ending a sandbox session when exiting."""
1072
+ error = None
1073
+ sandbox._enter_service_call() # pylint: disable=protected-access
1074
+
960
1075
  try:
961
1076
  with service as result:
962
1077
  yield result
1078
+ except BaseException as e:
1079
+ error = e
1080
+ sandbox.report_maybe_state_error(error)
1081
+ raise
963
1082
  finally:
964
- sandbox.end_session()
1083
+ sandbox_or_feature.on_activity(
1084
+ name=name,
1085
+ error=error,
1086
+ duration=time.time() - start_time,
1087
+ **kwargs,
1088
+ )
1089
+ is_topmost_call = sandbox._exit_service_call() # pylint: disable=protected-access
1090
+
1091
+ if new_session:
1092
+ assert is_topmost_call
1093
+ sandbox.end_session()
1094
+ elif isinstance(error, interface.SandboxStateError):
1095
+ sandbox.shutdown()