langfun 0.1.2.dev202510230805__py3-none-any.whl → 0.1.2.dev202511160804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langfun might be problematic. Click here for more details.

Files changed (146) hide show
  1. langfun/core/__init__.py +1 -0
  2. langfun/core/agentic/action.py +107 -12
  3. langfun/core/agentic/action_eval.py +9 -2
  4. langfun/core/agentic/action_test.py +25 -0
  5. langfun/core/async_support.py +32 -3
  6. langfun/core/coding/python/correction.py +19 -9
  7. langfun/core/coding/python/execution.py +14 -12
  8. langfun/core/coding/python/generation.py +21 -16
  9. langfun/core/coding/python/sandboxing.py +23 -3
  10. langfun/core/component.py +42 -3
  11. langfun/core/concurrent.py +70 -6
  12. langfun/core/concurrent_test.py +1 -0
  13. langfun/core/console.py +1 -1
  14. langfun/core/data/conversion/anthropic.py +12 -3
  15. langfun/core/data/conversion/anthropic_test.py +8 -6
  16. langfun/core/data/conversion/gemini.py +9 -2
  17. langfun/core/data/conversion/gemini_test.py +12 -9
  18. langfun/core/data/conversion/openai.py +145 -31
  19. langfun/core/data/conversion/openai_test.py +161 -17
  20. langfun/core/eval/base.py +47 -43
  21. langfun/core/eval/base_test.py +4 -4
  22. langfun/core/eval/matching.py +5 -2
  23. langfun/core/eval/patching.py +3 -3
  24. langfun/core/eval/scoring.py +4 -3
  25. langfun/core/eval/v2/__init__.py +1 -0
  26. langfun/core/eval/v2/checkpointing.py +39 -5
  27. langfun/core/eval/v2/checkpointing_test.py +1 -1
  28. langfun/core/eval/v2/eval_test_helper.py +96 -0
  29. langfun/core/eval/v2/evaluation.py +87 -15
  30. langfun/core/eval/v2/evaluation_test.py +9 -3
  31. langfun/core/eval/v2/example.py +45 -39
  32. langfun/core/eval/v2/example_test.py +3 -3
  33. langfun/core/eval/v2/experiment.py +51 -8
  34. langfun/core/eval/v2/metric_values.py +31 -3
  35. langfun/core/eval/v2/metric_values_test.py +32 -0
  36. langfun/core/eval/v2/metrics.py +157 -44
  37. langfun/core/eval/v2/metrics_test.py +39 -18
  38. langfun/core/eval/v2/progress.py +30 -1
  39. langfun/core/eval/v2/progress_test.py +27 -0
  40. langfun/core/eval/v2/progress_tracking_test.py +3 -0
  41. langfun/core/eval/v2/reporting.py +90 -71
  42. langfun/core/eval/v2/reporting_test.py +20 -6
  43. langfun/core/eval/v2/runners/__init__.py +26 -0
  44. langfun/core/eval/v2/{runners.py → runners/base.py} +22 -124
  45. langfun/core/eval/v2/runners/debug.py +40 -0
  46. langfun/core/eval/v2/runners/debug_test.py +79 -0
  47. langfun/core/eval/v2/runners/parallel.py +100 -0
  48. langfun/core/eval/v2/runners/parallel_test.py +98 -0
  49. langfun/core/eval/v2/runners/sequential.py +47 -0
  50. langfun/core/eval/v2/runners/sequential_test.py +175 -0
  51. langfun/core/langfunc.py +45 -130
  52. langfun/core/langfunc_test.py +6 -4
  53. langfun/core/language_model.py +103 -16
  54. langfun/core/language_model_test.py +9 -3
  55. langfun/core/llms/__init__.py +7 -1
  56. langfun/core/llms/anthropic.py +157 -2
  57. langfun/core/llms/azure_openai.py +29 -17
  58. langfun/core/llms/cache/base.py +25 -3
  59. langfun/core/llms/cache/in_memory.py +48 -7
  60. langfun/core/llms/cache/in_memory_test.py +14 -4
  61. langfun/core/llms/compositional.py +25 -1
  62. langfun/core/llms/deepseek.py +30 -2
  63. langfun/core/llms/fake.py +32 -1
  64. langfun/core/llms/gemini.py +14 -9
  65. langfun/core/llms/google_genai.py +29 -1
  66. langfun/core/llms/groq.py +28 -3
  67. langfun/core/llms/llama_cpp.py +23 -4
  68. langfun/core/llms/openai.py +36 -3
  69. langfun/core/llms/openai_compatible.py +148 -27
  70. langfun/core/llms/openai_compatible_test.py +207 -20
  71. langfun/core/llms/openai_test.py +0 -2
  72. langfun/core/llms/rest.py +12 -1
  73. langfun/core/llms/vertexai.py +51 -8
  74. langfun/core/logging.py +1 -1
  75. langfun/core/mcp/client.py +77 -22
  76. langfun/core/mcp/client_test.py +8 -35
  77. langfun/core/mcp/session.py +94 -29
  78. langfun/core/mcp/session_test.py +54 -0
  79. langfun/core/mcp/tool.py +151 -22
  80. langfun/core/mcp/tool_test.py +197 -0
  81. langfun/core/memory.py +1 -0
  82. langfun/core/message.py +160 -55
  83. langfun/core/message_test.py +65 -81
  84. langfun/core/modalities/__init__.py +8 -0
  85. langfun/core/modalities/audio.py +21 -1
  86. langfun/core/modalities/image.py +19 -1
  87. langfun/core/modalities/mime.py +62 -3
  88. langfun/core/modalities/pdf.py +19 -1
  89. langfun/core/modalities/video.py +21 -1
  90. langfun/core/modality.py +167 -29
  91. langfun/core/modality_test.py +42 -12
  92. langfun/core/natural_language.py +1 -1
  93. langfun/core/sampling.py +4 -4
  94. langfun/core/sampling_test.py +20 -4
  95. langfun/core/structured/__init__.py +2 -24
  96. langfun/core/structured/completion.py +34 -44
  97. langfun/core/structured/completion_test.py +23 -43
  98. langfun/core/structured/description.py +54 -50
  99. langfun/core/structured/function_generation.py +29 -12
  100. langfun/core/structured/mapping.py +81 -37
  101. langfun/core/structured/parsing.py +95 -79
  102. langfun/core/structured/parsing_test.py +0 -3
  103. langfun/core/structured/querying.py +215 -142
  104. langfun/core/structured/querying_test.py +65 -29
  105. langfun/core/structured/schema/__init__.py +48 -0
  106. langfun/core/structured/schema/base.py +664 -0
  107. langfun/core/structured/schema/base_test.py +531 -0
  108. langfun/core/structured/schema/json.py +174 -0
  109. langfun/core/structured/schema/json_test.py +121 -0
  110. langfun/core/structured/schema/python.py +316 -0
  111. langfun/core/structured/schema/python_test.py +410 -0
  112. langfun/core/structured/schema_generation.py +33 -14
  113. langfun/core/structured/scoring.py +47 -36
  114. langfun/core/structured/tokenization.py +26 -11
  115. langfun/core/subscription.py +2 -2
  116. langfun/core/template.py +174 -49
  117. langfun/core/template_test.py +123 -17
  118. langfun/env/__init__.py +8 -2
  119. langfun/env/base_environment.py +320 -128
  120. langfun/env/base_environment_test.py +473 -0
  121. langfun/env/base_feature.py +92 -15
  122. langfun/env/base_feature_test.py +228 -0
  123. langfun/env/base_sandbox.py +84 -361
  124. langfun/env/base_sandbox_test.py +1235 -0
  125. langfun/env/event_handlers/__init__.py +1 -1
  126. langfun/env/event_handlers/chain.py +233 -0
  127. langfun/env/event_handlers/chain_test.py +253 -0
  128. langfun/env/event_handlers/event_logger.py +95 -98
  129. langfun/env/event_handlers/event_logger_test.py +21 -21
  130. langfun/env/event_handlers/metric_writer.py +225 -140
  131. langfun/env/event_handlers/metric_writer_test.py +23 -6
  132. langfun/env/interface.py +854 -40
  133. langfun/env/interface_test.py +112 -2
  134. langfun/env/load_balancers_test.py +23 -2
  135. langfun/env/test_utils.py +126 -84
  136. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511160804.dist-info}/METADATA +1 -1
  137. langfun-0.1.2.dev202511160804.dist-info/RECORD +211 -0
  138. langfun/core/eval/v2/runners_test.py +0 -343
  139. langfun/core/structured/schema.py +0 -987
  140. langfun/core/structured/schema_test.py +0 -982
  141. langfun/env/base_test.py +0 -1481
  142. langfun/env/event_handlers/base.py +0 -350
  143. langfun-0.1.2.dev202510230805.dist-info/RECORD +0 -195
  144. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511160804.dist-info}/WHEEL +0 -0
  145. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511160804.dist-info}/licenses/LICENSE +0 -0
  146. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511160804.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,473 @@
1
+ # Copyright 2025 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import time
15
+ import unittest
16
+
17
+ from langfun.env import interface
18
+ from langfun.env import test_utils
19
+
20
+ TestingEnvironment = test_utils.TestingEnvironment
21
+ TestingSandbox = test_utils.TestingSandbox
22
+ TestingFeature = test_utils.TestingFeature
23
+ TestingEventHandler = test_utils.TestingEventHandler
24
+
25
+
26
+ class BaseEnvironmentTests(unittest.TestCase):
27
+
28
+ def test_basics(self):
29
+ env = TestingEnvironment(
30
+ image_ids=['test_image'],
31
+ root_dir='/tmp',
32
+ pool_size=0,
33
+ features={'test_feature': TestingFeature()},
34
+ outage_grace_period=1,
35
+ outage_retry_interval=0,
36
+ )
37
+ self.assertIsNone(interface.Environment.current())
38
+ self.assertEqual(env.image_ids, ['test_image'])
39
+ self.assertFalse(env.supports_dynamic_image_loading)
40
+ self.assertEqual(env.status, interface.Environment.Status.CREATED)
41
+ self.assertFalse(env.is_online)
42
+ self.assertEqual(env.min_pool_size('test_image'), 0)
43
+ self.assertEqual(env.max_pool_size('test_image'), 0)
44
+ self.assertEqual(env.sandbox_pool, {})
45
+ self.assertEqual(env.id, interface.Environment.Id('testing-env'))
46
+ self.assertEqual(env.outage_grace_period, 1)
47
+ self.assertEqual(env.features['test_feature'].name, 'test_feature')
48
+
49
+ self.assertIsNone(env.start_time)
50
+
51
+ with env:
52
+ self.assertEqual(env.status, interface.Environment.Status.ONLINE)
53
+ self.assertIs(interface.Environment.current(), env)
54
+ self.assertTrue(env.is_online)
55
+ self.assertIsNotNone(env.start_time)
56
+ self.assertEqual(env.offline_duration, 0.0)
57
+ self.assertEqual(env.sandbox_pool, {})
58
+ self.assertEqual(env.working_dir, '/tmp/testing-env')
59
+
60
+ with env.sandbox('session1') as sb:
61
+ self.assertEqual(
62
+ sb.id, interface.Sandbox.Id(
63
+ environment_id=env.id,
64
+ image_id=sb.image_id,
65
+ sandbox_id='0')
66
+ )
67
+ self.assertEqual(sb.session_id, 'session1')
68
+ self.assertEqual(sb.working_dir, '/tmp/testing-env/test_image/0')
69
+ self.assertTrue(sb.is_online)
70
+ self.assertIs(sb.test_feature, sb.features['test_feature'])
71
+ self.assertEqual(
72
+ sb.test_feature.working_dir,
73
+ '/tmp/testing-env/test_image/0/test_feature'
74
+ )
75
+ with self.assertRaises(AttributeError):
76
+ _ = sb.test_feature2
77
+ self.assertFalse(sb.is_online)
78
+
79
+ with self.assertRaisesRegex(
80
+ ValueError, 'Environment .* does not serve image ID .*'
81
+ ):
82
+ env.sandbox(image_id='test_image2')
83
+
84
+ with env.test_feature() as feature:
85
+ self.assertIsInstance(feature, TestingFeature)
86
+ self.assertEqual(
87
+ feature.sandbox.status, interface.Sandbox.Status.IN_SESSION
88
+ )
89
+ self.assertTrue(
90
+ feature.sandbox.session_id.startswith('test_feature-session')
91
+ )
92
+
93
+ with self.assertRaises(AttributeError):
94
+ _ = env.test_feature2
95
+
96
+ def test_dynamic_image_loading(self):
97
+ env = TestingEnvironment(
98
+ image_ids=[],
99
+ supports_dynamic_image_loading=True,
100
+ pool_size=0,
101
+ features={'test_feature': TestingFeature()},
102
+ outage_grace_period=1,
103
+ outage_retry_interval=0,
104
+ )
105
+ with env:
106
+ with env.sandbox(image_id='test_image2') as sb:
107
+ self.assertEqual(sb.image_id, 'test_image2')
108
+
109
+ with self.assertRaisesRegex(
110
+ ValueError, 'Environment .* does not have a default image ID.'
111
+ ):
112
+ env.sandbox()
113
+
114
+ def test_dynamic_image_loading_with_pooling(self):
115
+ env = TestingEnvironment(
116
+ image_ids=[],
117
+ supports_dynamic_image_loading=True,
118
+ pool_size=2,
119
+ features={'test_feature': TestingFeature()},
120
+ outage_grace_period=1,
121
+ outage_retry_interval=0,
122
+ )
123
+ with env:
124
+ with env.sandbox(image_id='test_image'):
125
+ self.assertEqual(len(env.sandbox_pool['test_image']), 1)
126
+
127
+ with env.sandbox(image_id='test_image'):
128
+ self.assertEqual(len(env.sandbox_pool['test_image']), 2)
129
+
130
+ with self.assertRaises(interface.EnvironmentOverloadError):
131
+ with env.sandbox(image_id='test_image'):
132
+ pass
133
+ self.assertEqual(len(env.sandbox_pool['test_image']), 2)
134
+
135
+ with env.sandbox(image_id='test_image'):
136
+ self.assertEqual(len(env.sandbox_pool['test_image']), 2)
137
+
138
+ def test_image_feature_mappings(self):
139
+ env = TestingEnvironment(
140
+ image_ids=[
141
+ 'test_image1',
142
+ 'test_image2',
143
+ ],
144
+ features={
145
+ 'test_feature': TestingFeature(
146
+ applicable_images=['test_image1.*']
147
+ ),
148
+ 'test_feature2': TestingFeature(
149
+ applicable_images=['test_image2.*']
150
+ ),
151
+ 'test_feature3': TestingFeature(
152
+ applicable_images=['test_image.*']
153
+ ),
154
+ },
155
+ pool_size=0,
156
+ outage_grace_period=1,
157
+ outage_retry_interval=0,
158
+ sandbox_keepalive_interval=0,
159
+ )
160
+ with env:
161
+ with env.sandbox(image_id='test_image1') as sb:
162
+ self.assertIn('test_feature', sb.features)
163
+ self.assertNotIn('test_feature2', sb.features)
164
+ self.assertIn('test_feature3', sb.features)
165
+
166
+ with env.sandbox(image_id='test_image2') as sb:
167
+ self.assertNotIn('test_feature', sb.features)
168
+ self.assertIn('test_feature2', sb.features)
169
+ self.assertIn('test_feature3', sb.features)
170
+
171
+ with env.test_feature() as feature:
172
+ self.assertEqual(feature.sandbox.image_id, 'test_image1')
173
+
174
+ with self.assertRaisesRegex(
175
+ ValueError, 'Feature .* is not applicable to .*'
176
+ ):
177
+ with env.test_feature(image_id='test_image2'):
178
+ pass
179
+
180
+ with env.test_feature2() as feature:
181
+ self.assertEqual(feature.sandbox.image_id, 'test_image2')
182
+
183
+ with env.test_feature3() as feature:
184
+ self.assertEqual(feature.sandbox.image_id, 'test_image1')
185
+
186
+ with env.test_feature3(image_id='test_image2') as feature:
187
+ self.assertEqual(feature.sandbox.image_id, 'test_image2')
188
+
189
+ def test_feature_applicability_check(self):
190
+ with self.assertRaisesRegex(
191
+ ValueError, 'Feature .* is not applicable to .*'
192
+ ):
193
+ TestingEnvironment(
194
+ image_ids=[
195
+ 'test_image1',
196
+ ],
197
+ features={
198
+ 'test_feature2': TestingFeature(
199
+ applicable_images=['test_image2.*']
200
+ ),
201
+ },
202
+ )
203
+ env = TestingEnvironment(
204
+ image_ids=[],
205
+ supports_dynamic_image_loading=True,
206
+ features={
207
+ 'test_feature2': TestingFeature(
208
+ applicable_images=['test_image2.*']
209
+ ),
210
+ },
211
+ pool_size=0
212
+ )
213
+ with env:
214
+ with self.assertRaisesRegex(
215
+ ValueError, 'No image ID found for feature .*'
216
+ ):
217
+ with env.test_feature2():
218
+ pass
219
+
220
+ # Dynamically loaded IDs.
221
+ with env.test_feature2(image_id='test_image2') as feature:
222
+ self.assertEqual(feature.sandbox.image_id, 'test_image2')
223
+
224
+ def test_pool_size(self):
225
+ env = TestingEnvironment(
226
+ image_ids=['test_image'],
227
+ pool_size=1,
228
+ outage_grace_period=1,
229
+ outage_retry_interval=0,
230
+ )
231
+ self.assertEqual(env.min_pool_size('test_image'), 1)
232
+ self.assertEqual(env.max_pool_size('test_image'), 1)
233
+
234
+ env = TestingEnvironment(
235
+ image_ids=['test_image'],
236
+ pool_size=(0, 256),
237
+ outage_grace_period=1,
238
+ outage_retry_interval=0,
239
+ )
240
+ self.assertEqual(env.min_pool_size('test_image'), 0)
241
+ self.assertEqual(env.max_pool_size('test_image'), 256)
242
+
243
+ env = TestingEnvironment(
244
+ image_ids=['test_image'],
245
+ pool_size={
246
+ 'test_.*': (0, 128),
247
+ 'my.*': (5, 64),
248
+ 'exact_image_name': 10,
249
+ },
250
+ outage_grace_period=1,
251
+ outage_retry_interval=0,
252
+ )
253
+ self.assertEqual(env.min_pool_size('test_image'), 0)
254
+ self.assertEqual(env.max_pool_size('test_image'), 128)
255
+ self.assertEqual(env.min_pool_size('my_image'), 5)
256
+ self.assertEqual(env.max_pool_size('my_image'), 64)
257
+ self.assertEqual(env.min_pool_size('exact_image_name'), 10)
258
+ self.assertEqual(env.max_pool_size('exact_image_name'), 10)
259
+ self.assertEqual(env.min_pool_size('some_image'), 0) # default
260
+ self.assertEqual(env.max_pool_size('some_image'), 256) # default
261
+
262
+ def test_del(self):
263
+ env = TestingEnvironment(
264
+ features={'test_feature': TestingFeature()},
265
+ pool_size=0,
266
+ outage_grace_period=1,
267
+ outage_retry_interval=0,
268
+ sandbox_keepalive_interval=0,
269
+ )
270
+ env.start()
271
+ sb = env.acquire()
272
+ del sb
273
+ del env
274
+
275
+ def test_acquire_env_offline(self):
276
+ env = TestingEnvironment(
277
+ features={'test_feature': TestingFeature()},
278
+ pool_size=0,
279
+ outage_grace_period=1,
280
+ outage_retry_interval=0,
281
+ sandbox_keepalive_interval=0,
282
+ )
283
+ with self.assertRaises(interface.EnvironmentOutageError):
284
+ env.acquire()
285
+
286
+ def test_acquire_no_pooling(self):
287
+ env = TestingEnvironment(
288
+ features={'test_feature': TestingFeature()},
289
+ pool_size=0,
290
+ outage_grace_period=1,
291
+ outage_retry_interval=0,
292
+ sandbox_keepalive_interval=0,
293
+ )
294
+ with env:
295
+ sb = env.acquire()
296
+ self.assertEqual(sb.status, interface.Sandbox.Status.ACQUIRED)
297
+ self.assertIsNone(env.working_dir)
298
+ self.assertIsNone(sb.working_dir)
299
+ self.assertIsNone(sb.test_feature.working_dir)
300
+
301
+ def test_acquire_no_pooling_with_error(self):
302
+ env = TestingEnvironment(
303
+ features={
304
+ 'test_feature': TestingFeature(
305
+ simulate_setup_error=interface.SandboxStateError
306
+ )
307
+ },
308
+ pool_size=0,
309
+ outage_grace_period=1,
310
+ outage_retry_interval=0,
311
+ sandbox_keepalive_interval=0,
312
+ )
313
+ with env:
314
+ with self.assertRaises(interface.EnvironmentOutageError):
315
+ env.acquire()
316
+
317
+ def test_acquire_with_pooling(self):
318
+ env = TestingEnvironment(
319
+ features={'test_feature': TestingFeature()},
320
+ pool_size=1,
321
+ outage_grace_period=1,
322
+ outage_retry_interval=0,
323
+ sandbox_keepalive_interval=0,
324
+ )
325
+ with env:
326
+ sb = env.acquire()
327
+ self.assertEqual(sb.status, interface.Sandbox.Status.ACQUIRED)
328
+
329
+ def test_acquire_with_pooling_overload(self):
330
+ env = TestingEnvironment(
331
+ features={'test_feature': TestingFeature()},
332
+ pool_size=1,
333
+ outage_grace_period=1,
334
+ outage_retry_interval=0,
335
+ sandbox_keepalive_interval=0,
336
+ )
337
+ with env:
338
+ sb = env.acquire()
339
+ self.assertEqual(sb.status, interface.Sandbox.Status.ACQUIRED)
340
+ with self.assertRaises(interface.EnvironmentOverloadError):
341
+ env.acquire()
342
+
343
+ def test_acquire_with_growing_pool(self):
344
+ env = TestingEnvironment(
345
+ features={'test_feature': TestingFeature()},
346
+ pool_size=(1, 3),
347
+ outage_grace_period=1,
348
+ outage_retry_interval=0,
349
+ sandbox_keepalive_interval=0,
350
+ )
351
+ with env:
352
+ self.assertEqual(len(env.sandbox_pool['test_image']), 1)
353
+ self.assertEqual(
354
+ env.stats(),
355
+ {
356
+ 'sandbox': {
357
+ 'test_image': {
358
+ 'created': 0,
359
+ 'setting_up': 0,
360
+ 'ready': 1,
361
+ 'acquired': 0,
362
+ 'in_session': 0,
363
+ 'exiting_session': 0,
364
+ 'shutting_down': 0,
365
+ 'offline': 0,
366
+ }
367
+ }
368
+ }
369
+ )
370
+ sb = env.acquire()
371
+ self.assertEqual(sb.status, interface.Sandbox.Status.ACQUIRED)
372
+ self.assertEqual(
373
+ env.stats(),
374
+ {
375
+ 'sandbox': {
376
+ 'test_image': {
377
+ 'created': 0,
378
+ 'setting_up': 0,
379
+ 'ready': 0,
380
+ 'acquired': 1,
381
+ 'in_session': 0,
382
+ 'exiting_session': 0,
383
+ 'shutting_down': 0,
384
+ 'offline': 0,
385
+ }
386
+ }
387
+ }
388
+ )
389
+ self.assertEqual(len(env.sandbox_pool['test_image']), 1)
390
+ sb2 = env.acquire()
391
+ self.assertEqual(sb2.status, interface.Sandbox.Status.ACQUIRED)
392
+ self.assertEqual(len(env.sandbox_pool['test_image']), 2)
393
+ self.assertEqual(
394
+ env.stats(),
395
+ {
396
+ 'sandbox': {
397
+ 'test_image': {
398
+ 'created': 0,
399
+ 'setting_up': 0,
400
+ 'ready': 0,
401
+ 'acquired': 2,
402
+ 'in_session': 0,
403
+ 'exiting_session': 0,
404
+ 'shutting_down': 0,
405
+ 'offline': 0,
406
+ }
407
+ }
408
+ }
409
+ )
410
+ self.assertEqual(
411
+ env.stats(),
412
+ {
413
+ 'sandbox': {}
414
+ }
415
+ )
416
+
417
+ def test_acquire_with_growing_pool_failure(self):
418
+ env = TestingEnvironment(
419
+ features={'test_feature': TestingFeature()},
420
+ pool_size=(1, 3),
421
+ outage_grace_period=1,
422
+ outage_retry_interval=0,
423
+ sandbox_keepalive_interval=0,
424
+ )
425
+ with env:
426
+ self.assertEqual(len(env.sandbox_pool), 1)
427
+ sb = env.acquire()
428
+ self.assertEqual(sb.status, interface.Sandbox.Status.ACQUIRED)
429
+
430
+ # Make future sandbox setup fail.
431
+ env.features.test_feature.rebind(
432
+ simulate_setup_error=interface.SandboxStateError,
433
+ skip_notification=True
434
+ )
435
+ with self.assertRaises(interface.EnvironmentOutageError):
436
+ env.acquire()
437
+
438
+ def test_housekeep_error(self):
439
+ env = TestingEnvironment(
440
+ features={'test_feature': TestingFeature()},
441
+ pool_size=1,
442
+ proactive_session_setup=True,
443
+ outage_grace_period=1,
444
+ outage_retry_interval=0,
445
+ sandbox_keepalive_interval=0,
446
+ )
447
+ with env:
448
+ self.assertEqual(len(env.sandbox_pool), 1)
449
+ self.assertIn('test_image', env.sandbox_pool)
450
+ self.assertEqual(
451
+ env.sandbox_pool['test_image'][0].status,
452
+ interface.Sandbox.Status.READY
453
+ )
454
+ # Make future sandbox setup fail.
455
+ env.features.test_feature.rebind(
456
+ simulate_setup_error=interface.SandboxStateError,
457
+ skip_notification=True
458
+ )
459
+ with self.assertRaises(interface.SandboxStateError):
460
+ with env.sandbox() as sb:
461
+ sb.shell('bad command', raise_error=interface.SandboxStateError)
462
+ self.assertEqual(sb.status, interface.Sandbox.Status.OFFLINE)
463
+ self.assertEqual(len(sb.state_errors), 1)
464
+ sb_offline_time = time.time()
465
+ while time.time() - sb_offline_time < 10:
466
+ if not env.is_online:
467
+ break
468
+ time.sleep(0.5)
469
+ self.assertFalse(env.is_online)
470
+
471
+
472
+ if __name__ == '__main__':
473
+ unittest.main()
@@ -11,9 +11,9 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
- """Common base class for sandbox-based features.
14
+ """Common base class for environment features.
15
15
 
16
- This module provides an base class `BaseFeature` for sandbox-based features,
16
+ This module provides a base class `BaseFeature` for environment features,
17
17
  which provides event handlers for the feature lifecycle events, which can be
18
18
  overridden by subclasses to provide custom behaviors. Please note that this base
19
19
  class is intended to provide a convenient way to implement features, and not
@@ -22,17 +22,32 @@ coupled with `BaseEnvironment` and `BaseSandbox`, and is expected to work with
22
22
  the `Environment` and `Sandbox` interfaces directly.
23
23
  """
24
24
 
25
+ import contextlib
25
26
  import functools
26
27
  import os
28
+ import re
27
29
  import time
28
- from typing import Annotated, Callable
30
+ from typing import Annotated, Any, Callable, Iterator
29
31
 
30
32
  from langfun.env import interface
31
33
  import pyglove as pg
32
34
 
33
35
 
34
36
  class BaseFeature(interface.Feature):
35
- """Common base class for sandbox-based features."""
37
+ """Common base class for environment features."""
38
+
39
+ is_sandbox_based: Annotated[
40
+ bool,
41
+ 'Whether the feature is sandbox-based.'
42
+ ] = True
43
+
44
+ applicable_images: Annotated[
45
+ list[str],
46
+ (
47
+ 'A list of regular expressions for image IDs which enable '
48
+ 'this feature. By default, all images are enabled.'
49
+ )
50
+ ] = ['.*']
36
51
 
37
52
  housekeep_interval: Annotated[
38
53
  float | None,
@@ -101,10 +116,21 @@ class BaseFeature(interface.Feature):
101
116
  super()._on_parent_change(old_parent, new_parent)
102
117
  self.__dict__.pop('name', None)
103
118
 
119
+ @functools.cached_property
120
+ def environment(self) -> interface.Environment:
121
+ """Returns the environment that the feature is running in."""
122
+ if self._sandbox is not None:
123
+ return self._sandbox.environment
124
+ env = self.sym_ancestor(lambda v: isinstance(v, interface.Environment))
125
+ assert env is not None, 'Feature is not put into an environment.'
126
+ return env
127
+
104
128
  @property
105
- def sandbox(self) -> interface.Sandbox:
129
+ def sandbox(self) -> interface.Sandbox | None:
106
130
  """Returns the sandbox that the feature is running in."""
107
- assert self._sandbox is not None, 'Feature has not been set up yet.'
131
+ assert self._sandbox is not None or not self.is_sandbox_based, (
132
+ 'Feature has not been set up yet.'
133
+ )
108
134
  return self._sandbox
109
135
 
110
136
  @property
@@ -115,6 +141,12 @@ class BaseFeature(interface.Feature):
115
141
  return None
116
142
  return os.path.join(sandbox_workdir, self.name)
117
143
 
144
+ def is_applicable(self, image_id: str) -> bool:
145
+ """Returns True if the feature is applicable to the given image."""
146
+ return any(
147
+ re.fullmatch(regex, image_id) for regex in self.applicable_images
148
+ )
149
+
118
150
  #
119
151
  # Setup and teardown of the feature.
120
152
  #
@@ -135,7 +167,7 @@ class BaseFeature(interface.Feature):
135
167
  finally:
136
168
  event_handler(duration=time.time() - start_time, error=error)
137
169
 
138
- def setup(self, sandbox: interface.Sandbox) -> None:
170
+ def setup(self, sandbox: interface.Sandbox | None = None) -> None:
139
171
  """Sets up the feature."""
140
172
  self._sandbox = sandbox
141
173
  self._do(self._setup, self.on_setup)
@@ -173,7 +205,11 @@ class BaseFeature(interface.Feature):
173
205
  error: BaseException | None = None
174
206
  ) -> None:
175
207
  """Called when the feature is setup."""
176
- self.sandbox.on_feature_setup(self, duration, error)
208
+ self.environment.event_handler.on_feature_setup(
209
+ feature=self,
210
+ duration=duration,
211
+ error=error
212
+ )
177
213
 
178
214
  def on_teardown(
179
215
  self,
@@ -181,7 +217,11 @@ class BaseFeature(interface.Feature):
181
217
  error: BaseException | None = None
182
218
  ) -> None:
183
219
  """Called when the feature is teardown."""
184
- self.sandbox.on_feature_teardown(self, duration, error)
220
+ self.environment.event_handler.on_feature_teardown(
221
+ feature=self,
222
+ duration=duration,
223
+ error=error
224
+ )
185
225
 
186
226
  def on_housekeep(
187
227
  self,
@@ -190,8 +230,12 @@ class BaseFeature(interface.Feature):
190
230
  **kwargs
191
231
  ) -> None:
192
232
  """Called when the feature has done housekeeping."""
193
- self.sandbox.on_feature_housekeep(
194
- self, self._housekeep_counter, duration, error, **kwargs
233
+ self.environment.event_handler.on_feature_housekeep(
234
+ feature=self,
235
+ counter=self._housekeep_counter,
236
+ duration=duration,
237
+ error=error,
238
+ **kwargs
195
239
  )
196
240
 
197
241
  def on_setup_session(
@@ -200,7 +244,12 @@ class BaseFeature(interface.Feature):
200
244
  error: BaseException | None = None,
201
245
  ) -> None:
202
246
  """Called when the feature is setup for a user session."""
203
- self.sandbox.on_feature_setup_session(self, duration, error)
247
+ self.environment.event_handler.on_feature_setup_session(
248
+ feature=self,
249
+ session_id=self.session_id,
250
+ duration=duration,
251
+ error=error
252
+ )
204
253
 
205
254
  def on_teardown_session(
206
255
  self,
@@ -208,7 +257,12 @@ class BaseFeature(interface.Feature):
208
257
  error: BaseException | None = None,
209
258
  ) -> None:
210
259
  """Called when the feature is teardown for a user session."""
211
- self.sandbox.on_feature_teardown_session(self, duration, error)
260
+ self.environment.event_handler.on_feature_teardown_session(
261
+ feature=self,
262
+ session_id=self.session_id,
263
+ duration=duration,
264
+ error=error
265
+ )
212
266
 
213
267
  def on_activity(
214
268
  self,
@@ -218,10 +272,33 @@ class BaseFeature(interface.Feature):
218
272
  **kwargs
219
273
  ) -> None:
220
274
  """Called when a sandbox activity is performed."""
221
- self.sandbox.on_activity(
275
+ self.environment.event_handler.on_feature_activity(
222
276
  name=f'{self.name}.{name}',
223
277
  feature=self,
224
- error=error,
278
+ session_id=self.session_id,
225
279
  duration=duration,
280
+ error=error,
226
281
  **kwargs
227
282
  )
283
+
284
+ @contextlib.contextmanager
285
+ def track_activity(
286
+ self,
287
+ name: str,
288
+ **kwargs: Any
289
+ ) -> Iterator[None]:
290
+ """Context manager that tracks a feature activity."""
291
+ start_time = time.time()
292
+ error = None
293
+ try:
294
+ yield None
295
+ except BaseException as e: # pylint: disable=broad-except
296
+ error = e
297
+ raise
298
+ finally:
299
+ self.on_activity(
300
+ name=name,
301
+ duration=time.time() - start_time,
302
+ error=error,
303
+ **kwargs
304
+ )