langfun 0.1.2.dev202510280805__py3-none-any.whl → 0.1.2.dev202510300805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langfun might be problematic. Click here for more details.

@@ -23,8 +23,10 @@ Note that:
23
23
  """
24
24
 
25
25
  import abc
26
+ import collections
26
27
  import functools
27
28
  import random
29
+ import re
28
30
  import threading
29
31
  import time
30
32
  from typing import Annotated, Any
@@ -46,6 +48,23 @@ class BaseEnvironment(interface.Environment):
46
48
  maintenance.
47
49
  """
48
50
 
51
+ image_ids: Annotated[
52
+ list[str],
53
+ (
54
+ 'A list of static image IDs served by the environment. '
55
+ )
56
+ ]
57
+
58
+ supports_dynamic_image_loading: Annotated[
59
+ bool,
60
+ (
61
+ 'Whether the environment supports dynamic loading of images which is '
62
+ 'not included in the `image_ids`. `image_ids` could coexist with '
63
+ 'dynamic image loading, which allows users to specify an image id '
64
+ 'that is not included in the `image_ids`.'
65
+ )
66
+ ] = False
67
+
49
68
  root_dir: Annotated[
50
69
  str | None,
51
70
  (
@@ -55,11 +74,15 @@ class BaseEnvironment(interface.Environment):
55
74
  ] = None
56
75
 
57
76
  pool_size: Annotated[
58
- int | tuple[int, int],
77
+ int | tuple[int, int] | dict[str, int | tuple[int, int]],
59
78
  (
60
79
  'The (min_size, max_size) of the sandbox pool. If an integer, it '
61
- 'will be used as both min and max size. If 0, sandboxes will be '
62
- 'created on demand and shutdown when user session ends.'
80
+ 'will be used as both min and max size. If 0, all sandboxes will be '
81
+ 'created on demand and shutdown when user session ends. If a dict, '
82
+ 'users could configure the pool size based on image IDs. The keys '
83
+ 'are regular expressions for image IDs, and the values are '
84
+ '(min_size, max_size) tuples. For dynamic image IDs, min_size will '
85
+ 'ignored while max_size will be honored.'
63
86
  )
64
87
  ] = (0, 256)
65
88
 
@@ -146,15 +169,36 @@ class BaseEnvironment(interface.Environment):
146
169
 
147
170
  self._status = self.Status.CREATED
148
171
  self._start_time = None
149
- self._sandbox_pool = []
150
- self._next_pooled_sandbox_id = 0
172
+ self._sandbox_pool: dict[str, list[base_sandbox.BaseSandbox]] = (
173
+ collections.defaultdict(list)
174
+ )
175
+ self._next_sandbox_id: dict[str, int] = collections.defaultdict(int)
151
176
  self._random = (
152
177
  random if self.random_seed is None else random.Random(self.random_seed)
153
178
  )
154
-
155
179
  self._housekeep_thread = None
156
180
  self._offline_start_time = None
157
181
 
182
+ # Check image IDs and feature requirements.
183
+ self._check_image_ids()
184
+ self._check_feature_requirements()
185
+
186
+ def _check_image_ids(self) -> None:
187
+ """Checks image ids. Subclass could override this method."""
188
+
189
+ def _check_feature_requirements(self) -> None:
190
+ """Checks if the image ID is supported by the feature."""
191
+ if self.supports_dynamic_image_loading:
192
+ return
193
+ for name, feature in self.features.items():
194
+ if any(feature.is_applicable(image_id) for image_id in self.image_ids):
195
+ continue
196
+ raise ValueError(
197
+ f'Feature {name!r} is not applicable to all available images: '
198
+ f'{self.image_ids!r}. '
199
+ f'Applicable images: {feature.applicable_images}.'
200
+ )
201
+
158
202
  #
159
203
  # Subclasses must implement:
160
204
  #
@@ -162,6 +206,7 @@ class BaseEnvironment(interface.Environment):
162
206
  @abc.abstractmethod
163
207
  def _create_sandbox(
164
208
  self,
209
+ image_id: str,
165
210
  sandbox_id: str,
166
211
  reusable: bool,
167
212
  proactive_session_setup: bool,
@@ -170,6 +215,7 @@ class BaseEnvironment(interface.Environment):
170
215
  """Creates a sandbox with the given identifier.
171
216
 
172
217
  Args:
218
+ image_id: The image ID to use for the sandbox.
173
219
  sandbox_id: The identifier for the sandbox.
174
220
  reusable: Whether the sandbox is reusable across user sessions.
175
221
  proactive_session_setup: Whether the sandbox performs session setup work
@@ -185,13 +231,13 @@ class BaseEnvironment(interface.Environment):
185
231
  interface.SandboxStateError: If sandbox cannot be started.
186
232
  """
187
233
 
188
- def new_session_id(self) -> str:
234
+ def new_session_id(self, feature_hint: str | None = None) -> str:
189
235
  """Generates a random session ID."""
190
236
  suffix = uuid.UUID(
191
237
  bytes=bytes(bytes(self._random.getrandbits(8) for _ in range(16))),
192
238
  version=4
193
239
  ).hex[:7]
194
- return f'session-{suffix}'
240
+ return f'{feature_hint or "unknown"}-session-{suffix}'
195
241
 
196
242
  @property
197
243
  def housekeep_counter(self) -> int:
@@ -204,42 +250,59 @@ class BaseEnvironment(interface.Environment):
204
250
 
205
251
  def stats(self) -> dict[str, Any]:
206
252
  """Returns the stats of the environment."""
207
- stats_dict = {
208
- status.value: 0
209
- for status in interface.Sandbox.Status
210
- }
211
- for sandbox in self._sandbox_pool:
212
- stats_dict[sandbox.status.value] += 1
253
+ stats_by_image_id = {}
254
+ for image_id, sandboxes in self._sandbox_pool.items():
255
+ stats_dict = {
256
+ status.value: 0
257
+ for status in interface.Sandbox.Status
258
+ }
259
+ for sandbox in sandboxes:
260
+ stats_dict[sandbox.status.value] += 1
261
+ stats_by_image_id[image_id] = stats_dict
213
262
  return {
214
- 'sandbox': stats_dict,
263
+ 'sandbox': stats_by_image_id,
215
264
  }
216
265
 
217
266
  def _start(self) -> None:
218
267
  """Implementation of starting the environment."""
219
- if self.min_pool_size > 0:
268
+ sandbox_startup_infos = []
269
+ for image_id in self.image_ids:
270
+ next_sandbox_id = 0
271
+ if self.enable_pooling(image_id):
272
+ min_pool_size = self.min_pool_size(image_id)
273
+ for i in range(min_pool_size):
274
+ sandbox_startup_infos.append((image_id, i))
275
+ self._sandbox_pool[image_id] = [None] * min_pool_size
276
+ next_sandbox_id = min_pool_size
277
+ self._next_sandbox_id[image_id] = next_sandbox_id
278
+
279
+ def _start_sandbox(sandbox_startup_info) -> None:
280
+ image_id, index = sandbox_startup_info
281
+ self._sandbox_pool[image_id][index] = self._bring_up_sandbox_with_retry(
282
+ image_id=image_id,
283
+ sandbox_id=f'{index}:0',
284
+ shutdown_env_upon_outage=False
285
+ )
286
+
287
+ if sandbox_startup_infos:
220
288
  # Pre-allocate the sandbox pool before usage.
221
- self._sandbox_pool = [None] * self.min_pool_size
222
- for i, sandbox, _ in lf.concurrent_map(
223
- lambda i: self._bring_up_sandbox_with_retry(
224
- sandbox_id=f'{i}:0', shutdown_env_upon_outage=False
225
- ),
226
- range(self.min_pool_size),
227
- silence_on_errors=None,
228
- max_workers=min(
229
- self.pool_operation_max_parallelism,
230
- self.min_pool_size
231
- ),
232
- ):
233
- self._sandbox_pool[i] = sandbox
234
-
235
- self._next_sandbox_id = len(self._sandbox_pool)
236
-
237
- if self.enable_pooling:
238
- self._housekeep_thread = threading.Thread(
239
- target=self._housekeep_loop, daemon=True
289
+ _ = list(
290
+ lf.concurrent_map(
291
+ _start_sandbox,
292
+ sandbox_startup_infos,
293
+ silence_on_errors=None,
294
+ max_workers=min(
295
+ self.pool_operation_max_parallelism,
296
+ len(sandbox_startup_infos)
297
+ ),
298
+ )
240
299
  )
241
- self._housekeep_counter = 0
242
- self._housekeep_thread.start()
300
+
301
+ self._housekeep_thread = threading.Thread(
302
+ target=self._housekeep_loop, daemon=True
303
+ )
304
+ self._housekeep_counter = 0
305
+ self._housekeep_thread.start()
243
306
 
244
307
  def _shutdown(self) -> None:
245
308
  """Implementation of shutting down the environment."""
@@ -253,25 +316,30 @@ class BaseEnvironment(interface.Environment):
253
316
  sandbox.shutdown()
254
317
 
255
318
  if self._sandbox_pool:
256
- _ = list(
257
- lf.concurrent_map(
258
- _shutdown_sandbox,
259
- self._sandbox_pool,
260
- silence_on_errors=None,
261
- max_workers=min(
262
- self.pool_operation_max_parallelism,
263
- len(self._sandbox_pool)
264
- ),
265
- )
266
- )
267
- self._sandbox_pool = []
319
+ sandboxes = []
320
+ for sandbox in self._sandbox_pool.values():
321
+ sandboxes.extend(sandbox)
322
+ self._sandbox_pool = {}
323
+
324
+ if sandboxes:
325
+ _ = list(
326
+ lf.concurrent_map(
327
+ _shutdown_sandbox,
328
+ sandboxes,
329
+ silence_on_errors=None,
330
+ max_workers=min(
331
+ self.pool_operation_max_parallelism,
332
+ len(sandboxes)
333
+ ),
334
+ )
335
+ )
268
336
 
269
337
  #
270
338
  # Environment basics.
271
339
  #
272
340
 
273
341
  @property
274
- def sandbox_pool(self) -> list[base_sandbox.BaseSandbox]:
342
+ def sandbox_pool(self) -> dict[str, list[base_sandbox.BaseSandbox]]:
275
343
  """Returns the sandbox pool."""
276
344
  return self._sandbox_pool
277
345
 
@@ -280,11 +348,6 @@ class BaseEnvironment(interface.Environment):
280
348
  """Returns the working directory for the environment."""
281
349
  return self.id.working_dir(self.root_dir)
282
350
 
283
- @property
284
- def enable_pooling(self) -> bool:
285
- """Returns whether the environment enables pooling."""
286
- return self.max_pool_size > 0
287
-
288
351
  @property
289
352
  def status(self) -> interface.Environment.Status:
290
353
  """Returns whether the environment is online."""
@@ -294,19 +357,39 @@ class BaseEnvironment(interface.Environment):
294
357
  """Sets the status of the environment."""
295
358
  self._status = status
296
359
 
297
- @property
298
- def min_pool_size(self) -> int:
360
+ def enable_pooling(self, image_id: str) -> bool:
361
+ """Returns whether the environment enables pooling."""
362
+ return self.max_pool_size(image_id) > 0
363
+
364
+ def min_pool_size(self, image_id: str) -> int:
299
365
  """Returns the minimum size of the sandbox pool."""
300
- if isinstance(self.pool_size, int):
301
- return self.pool_size
302
- return self.pool_size[0]
366
+ return self._pool_size(image_id)[0]
303
367
 
304
- @property
305
- def max_pool_size(self) -> int:
368
+ def max_pool_size(self, image_id: str) -> int:
306
369
  """Returns the maximum size of the sandbox pool."""
307
- if isinstance(self.pool_size, int):
308
- return self.pool_size
309
- return self.pool_size[1]
370
+ return self._pool_size(image_id)[1]
371
+
372
+ def _pool_size(self, image_id: str) -> tuple[int, int]:
373
+ """Returns the minimum and maximum size of the sandbox pool."""
374
+ if isinstance(self.pool_size, dict):
375
+ if image_id in self.pool_size:
376
+ pool_size = self.pool_size[image_id]
377
+ else:
378
+ for k, v in self.pool_size.items():
379
+ if re.match(k, image_id):
380
+ pool_size = v
381
+ break
382
+ else:
383
+ # Default pool size is 0 and 256.
384
+ pool_size = (0, 256)
385
+ else:
386
+ pool_size = self.pool_size
387
+
388
+ if isinstance(pool_size, int):
389
+ return pool_size, pool_size
390
+ else:
391
+ assert isinstance(pool_size, tuple) and len(pool_size) == 2
392
+ return pool_size
310
393
 
311
394
  @property
312
395
  def start_time(self) -> float | None:
@@ -373,9 +456,16 @@ class BaseEnvironment(interface.Environment):
373
456
  # Environment operations.
374
457
  #
375
458
 
376
- def acquire(self) -> base_sandbox.BaseSandbox:
459
+ def acquire(
460
+ self,
461
+ image_id: str | None = None
462
+ ) -> base_sandbox.BaseSandbox:
377
463
  """Acquires a sandbox from the environment.
378
464
 
465
+ Args:
466
+ image_id: The image ID to use for the sandbox. If None, it will be
467
+ automatically determined by the environment.
468
+
379
469
  Returns:
380
470
  The acquired sandbox.
381
471
 
@@ -385,28 +475,50 @@ class BaseEnvironment(interface.Environment):
385
475
  interface.EnvironmentOverloadError: If the max pool size is reached and
386
476
  the grace period has passed.
387
477
  """
388
-
389
478
  if not self.is_online:
390
479
  raise interface.EnvironmentOutageError(
391
480
  f'Environment {self.id} is not alive.',
392
481
  environment=self,
393
482
  offline_duration=self.offline_duration,
394
483
  )
484
+ if image_id is None:
485
+ if not self.image_ids:
486
+ raise ValueError(
487
+ f'Environment {self.id} does not have a default image ID. '
488
+ 'Please specify the image ID explicitly.'
489
+ )
490
+ image_id = self.image_ids[0]
491
+ elif (image_id not in self.image_ids
492
+ and not self.supports_dynamic_image_loading):
493
+ raise ValueError(
494
+ f'Environment {self.id} does not serve image ID {image_id!r}. '
495
+ f'Please use one of the following image IDs: {self.image_ids!r} or '
496
+ f'set `{self.__class__.__name__}.supports_dynamic_image_ids` '
497
+ 'to True if dynamic image loading is supported.'
498
+ )
499
+ return self._acquire(image_id)
395
500
 
396
- if not self.enable_pooling:
501
+ def _acquire(
502
+ self,
503
+ image_id: str | None = None
504
+ ) -> base_sandbox.BaseSandbox:
505
+ """Acquires a sandbox from the environment."""
506
+ if not self.enable_pooling(image_id):
397
507
  return self._bring_up_sandbox_with_retry(
398
- sandbox_id=str(self._increment_sandbox_id()),
508
+ image_id=image_id,
509
+ sandbox_id=str(self._increment_sandbox_id(image_id)),
399
510
  set_acquired=True,
400
511
  )
401
512
 
402
513
  allocation_start_time = time.time()
514
+ sandbox_pool = self._sandbox_pool[image_id]
403
515
  while True:
404
516
  try:
405
517
  # We only append or replace items in the sandbox pool, therefore
406
518
  # there is no need to lock the pool.
407
- return self.load_balancer.acquire(self._sandbox_pool)
519
+ return self.load_balancer.acquire(sandbox_pool)
408
520
  except IndexError:
409
- if len(self._sandbox_pool) == self.max_pool_size:
521
+ if len(sandbox_pool) == self.max_pool_size(image_id):
410
522
  if time.time() - allocation_start_time > self.outage_grace_period:
411
523
  raise interface.EnvironmentOverloadError( # pylint: disable=raise-missing-from
412
524
  environment=self
@@ -415,11 +527,12 @@ class BaseEnvironment(interface.Environment):
415
527
  else:
416
528
  try:
417
529
  sandbox = self._bring_up_sandbox(
418
- sandbox_id=f'{self._increment_sandbox_id()}:0',
530
+ image_id=image_id,
531
+ sandbox_id=f'{self._increment_sandbox_id(image_id)}:0',
419
532
  set_acquired=True,
420
533
  )
421
534
  # Append is atomic and does not require locking.
422
- self._sandbox_pool.append(sandbox)
535
+ sandbox_pool.append(sandbox)
423
536
  return sandbox
424
537
  except (
425
538
  interface.EnvironmentError, interface.SandboxStateError
@@ -428,6 +541,7 @@ class BaseEnvironment(interface.Environment):
428
541
 
429
542
  def _bring_up_sandbox(
430
543
  self,
544
+ image_id: str,
431
545
  sandbox_id: str,
432
546
  set_acquired: bool = False,
433
547
  ) -> base_sandbox.BaseSandbox:
@@ -435,8 +549,9 @@ class BaseEnvironment(interface.Environment):
435
549
  env_error = None
436
550
  try:
437
551
  sandbox = self._create_sandbox(
552
+ image_id=image_id,
438
553
  sandbox_id=sandbox_id,
439
- reusable=self.enable_pooling,
554
+ reusable=self.enable_pooling(image_id),
440
555
  proactive_session_setup=self.proactive_session_setup,
441
556
  keepalive_interval=self.sandbox_keepalive_interval,
442
557
  )
@@ -457,6 +572,7 @@ class BaseEnvironment(interface.Environment):
457
572
 
458
573
  def _bring_up_sandbox_with_retry(
459
574
  self,
575
+ image_id: str,
460
576
  sandbox_id: str,
461
577
  set_acquired: bool = False,
462
578
  shutdown_env_upon_outage: bool = True,
@@ -464,6 +580,7 @@ class BaseEnvironment(interface.Environment):
464
580
  """Brings up a new sandbox with retry until grace period is passed.
465
581
 
466
582
  Args:
583
+ image_id: The image ID to use for the sandbox.
467
584
  sandbox_id: The ID of the sandbox to bring up.
468
585
  set_acquired: If True, the sandbox will be marked as acquired.
469
586
  shutdown_env_upon_outage: Whether to shutdown the environment when the
@@ -479,15 +596,15 @@ class BaseEnvironment(interface.Environment):
479
596
  while True:
480
597
  try:
481
598
  return self._bring_up_sandbox(
482
- sandbox_id=sandbox_id, set_acquired=set_acquired
599
+ image_id=image_id, sandbox_id=sandbox_id, set_acquired=set_acquired
483
600
  )
484
601
  except (interface.EnvironmentError, interface.SandboxStateError) as e:
485
602
  self._report_outage_or_wait(e, shutdown_env_upon_outage)
486
603
 
487
- def _increment_sandbox_id(self) -> int:
604
+ def _increment_sandbox_id(self, image_id: str) -> int:
488
605
  """Returns the next pooled sandbox ID."""
489
- x = self._next_sandbox_id
490
- self._next_sandbox_id += 1
606
+ x = self._next_sandbox_id[image_id]
607
+ self._next_sandbox_id[image_id] += 1
491
608
  return x
492
609
 
493
610
  def _report_outage_or_wait(
@@ -511,26 +628,39 @@ class BaseEnvironment(interface.Environment):
511
628
 
512
629
  def _housekeep_loop(self) -> None:
513
630
  """Housekeeping loop for the environment."""
631
+ def _indices_by_image_id(
632
+ entries: list[tuple[str, int, Any]]
633
+ ) -> dict[str, list[int]]:
634
+ indices_by_image_id = collections.defaultdict(list)
635
+ for image_id, i, _ in entries:
636
+ indices_by_image_id[image_id].append(i)
637
+ return indices_by_image_id
638
+
514
639
  while self._status not in (self.Status.SHUTTING_DOWN, self.Status.OFFLINE):
515
640
  housekeep_start_time = time.time()
516
641
 
517
642
  is_online = True
518
- dead_pool_indices = [
519
- i for i, s in enumerate(self._sandbox_pool)
520
- if s.status == interface.Sandbox.Status.OFFLINE
521
- ]
522
- replaced_indices = []
523
-
524
- if dead_pool_indices:
525
- replaced_indices = self._replace_dead_sandboxes(dead_pool_indices)
526
- if not replaced_indices:
643
+ dead_sandbox_entries = []
644
+ for image_id, sandboxes in self._sandbox_pool.items():
645
+ for i, sandbox in enumerate(sandboxes):
646
+ if sandbox.status == interface.Sandbox.Status.OFFLINE:
647
+ dead_sandbox_entries.append((image_id, i, sandbox))
648
+
649
+ replaced_indices_by_image_id = {}
650
+
651
+ if dead_sandbox_entries:
652
+ replaced_indices_by_image_id = self._replace_dead_sandboxes(
653
+ dead_sandbox_entries
654
+ )
655
+ if not replaced_indices_by_image_id:
527
656
  is_online = self.offline_duration < self.outage_grace_period
528
657
 
529
658
  self._housekeep_counter += 1
530
659
  duration = time.time() - housekeep_start_time
660
+
531
661
  kwargs = dict(
532
- dead_pool_indices=dead_pool_indices,
533
- replaced_indices=replaced_indices,
662
+ dead_sandboxes=_indices_by_image_id(dead_sandbox_entries),
663
+ replaced_sandboxes=replaced_indices_by_image_id,
534
664
  offline_duration=self.offline_duration,
535
665
  )
536
666
  if is_online:
@@ -546,49 +676,61 @@ class BaseEnvironment(interface.Environment):
546
676
  **kwargs
547
677
  )
548
678
 
549
- def _replace_dead_sandboxes(self, dead_pool_indices: list[int]) -> list[int]:
679
+ def _replace_dead_sandboxes(
680
+ self,
681
+ dead_sandbox_entries: list[tuple[str, int, base_sandbox.BaseSandbox]]
682
+ ) -> dict[str, list[int]]:
550
683
  """Replaces a dead sandbox with a new one.
551
684
 
552
685
  Args:
553
- dead_pool_indices: The indices of the dead sandboxes to replace.
686
+ dead_sandbox_entries: A list of tuples (image_id, index, sandbox) of
687
+ dead sandboxes to replace.
554
688
 
555
689
  Returns:
556
- Successfully replaced indices.
690
+ Successfully replaced sandboxes in a dict of image ID to a list of
691
+ indices.
557
692
  """
558
693
  pg.logging.warning(
559
694
  '[%s]: %s maintenance: '
560
695
  'Replacing %d dead sandbox(es) with new ones...',
561
696
  self.id,
562
697
  self.__class__.__name__,
563
- len(dead_pool_indices),
698
+ len(dead_sandbox_entries),
564
699
  )
565
- def _replace(i: int):
566
- generation = int(self._sandbox_pool[i].id.sandbox_id.split(':')[1])
567
- self._sandbox_pool[i] = self._bring_up_sandbox(f'{i}:{generation + 1}')
700
+ def _replace(sandbox_entry: tuple[str, int, base_sandbox.BaseSandbox]):
701
+ image_id, i, sandbox = sandbox_entry
702
+ generation = int(sandbox.id.sandbox_id.split(':')[-1])
703
+ replaced_sandbox = self._bring_up_sandbox(
704
+ image_id=image_id,
705
+ sandbox_id=f'{i}:{generation + 1}'
706
+ )
707
+ self._sandbox_pool[image_id][i] = replaced_sandbox
568
708
 
569
709
  # TODO(daiyip): Consider to loose the condition to allow some dead
570
710
  # sandboxes to be replaced successfully.
571
- replaced_indices = []
572
- for index, _, error in lf.concurrent_map(
573
- _replace, dead_pool_indices,
711
+ replaced_indices_by_image_id = collections.defaultdict(list)
712
+ num_replaced = 0
713
+ for (image_id, index, _), _, error in lf.concurrent_map(
714
+ _replace, dead_sandbox_entries,
574
715
  max_workers=min(
575
716
  self.pool_operation_max_parallelism,
576
- len(dead_pool_indices)
717
+ len(dead_sandbox_entries)
577
718
  ),
578
719
  ):
579
720
  if error is None:
580
- replaced_indices.append(index)
721
+ replaced_indices_by_image_id[image_id].append(index)
722
+ num_replaced += 1
581
723
 
582
724
  pg.logging.warning(
583
725
  '[%s]: %s maintenance: '
584
726
  '%d/%d dead sandbox(es) have been replaced with new ones. (slots=%s)',
585
727
  self.id,
586
728
  self.__class__.__name__,
587
- len(replaced_indices),
588
- len(dead_pool_indices),
589
- replaced_indices
729
+ num_replaced,
730
+ len(dead_sandbox_entries),
731
+ replaced_indices_by_image_id,
590
732
  )
591
- return replaced_indices
733
+ return replaced_indices_by_image_id
592
734
 
593
735
  #
594
736
  # Event handlers subclasses can override.
@@ -24,6 +24,7 @@ the `Environment` and `Sandbox` interfaces directly.
24
24
 
25
25
  import functools
26
26
  import os
27
+ import re
27
28
  import time
28
29
  from typing import Annotated, Callable
29
30
 
@@ -34,6 +35,14 @@ import pyglove as pg
34
35
  class BaseFeature(interface.Feature):
35
36
  """Common base class for sandbox-based features."""
36
37
 
38
+ applicable_images: Annotated[
39
+ list[str],
40
+ (
41
+ 'A list of regular expressions for image IDs which enable '
42
+ 'this feature. By default, all images are enabled.'
43
+ )
44
+ ] = ['.*']
45
+
37
46
  housekeep_interval: Annotated[
38
47
  float | None,
39
48
  'Interval in seconds for feature housekeeping.'
@@ -115,6 +124,12 @@ class BaseFeature(interface.Feature):
115
124
  return None
116
125
  return os.path.join(sandbox_workdir, self.name)
117
126
 
127
+ def is_applicable(self, image_id: str) -> bool:
128
+ """Returns True if the feature is applicable to the given image."""
129
+ return any(
130
+ re.fullmatch(regex, image_id) for regex in self.applicable_images
131
+ )
132
+
118
133
  #
119
134
  # Setup and teardown of the feature.
120
135
  #