continual-foragax 0.30.1__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: continual-foragax
3
- Version: 0.30.1
3
+ Version: 0.31.0
4
4
  Summary: A continual reinforcement learning benchmark
5
5
  Author-email: Steven Tang <stang5@ualberta.ca>
6
6
  Requires-Python: >=3.8
@@ -1,8 +1,8 @@
1
1
  foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
3
- foragax/env.py,sha256=4NZ5JsUGjAepmzw2uxu5_ikyVZnZ7vazy062Xzx22Zg,27481
4
- foragax/objects.py,sha256=M0nECANGfUvvBRMKSS7akGtoO2Suv5eroI-9Aj326sw,10368
5
- foragax/registry.py,sha256=Dxg6cWIPwg91fNrCPxADJv35u6jFg_8dI5iTpCMFEFA,15229
3
+ foragax/env.py,sha256=j7u8Xv8Fz4-GZU0SXXUP40Nz4NCvRrX02h9gl8PZMK4,32456
4
+ foragax/objects.py,sha256=FBO0k4X6tKidWWsk6E-GotA9jXYXc8GXjZxNQJItvWI,14119
5
+ foragax/registry.py,sha256=zF3RtW8ssWvnDKS31xJ_Ac9hP5sv8DWy4xBu4uCbHpQ,16060
6
6
  foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
7
7
  foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
8
8
  foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt,sha256=N7URbX6VlCZvCboUogYjMzy1I-0cfNPOn0QTLSHHfQ0,1776751
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
128
128
  foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
129
129
  foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
130
130
  foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
131
- continual_foragax-0.30.1.dist-info/METADATA,sha256=9iwHDGT1ZbvjL_CNRFQRQsPBPWKTFBetxrJPk_OXKug,4897
132
- continual_foragax-0.30.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
- continual_foragax-0.30.1.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
- continual_foragax-0.30.1.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
- continual_foragax-0.30.1.dist-info/RECORD,,
131
+ continual_foragax-0.31.0.dist-info/METADATA,sha256=o5l0-uCkTUx4jEV9qc49cOagYlRcksIBUkumbbPde8U,4897
132
+ continual_foragax-0.31.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
+ continual_foragax-0.31.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
+ continual_foragax-0.31.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
+ continual_foragax-0.31.0.dist-info/RECORD,,
foragax/env.py CHANGED
@@ -62,6 +62,7 @@ class EnvState(environment.EnvState):
62
62
  biome_grid: jax.Array
63
63
  time: int
64
64
  digestion_buffer: jax.Array
65
+ object_spawn_time_grid: jax.Array
65
66
 
66
67
 
67
68
  class ForagaxEnv(environment.Environment):
@@ -114,6 +115,12 @@ class ForagaxEnv(environment.Environment):
114
115
  self.reward_fns = [o.reward for o in objects]
115
116
  self.regen_delay_fns = [o.regen_delay for o in objects]
116
117
  self.reward_delay_fns = [o.reward_delay for o in objects]
118
+ self.expiry_regen_delay_fns = [o.expiry_regen_delay for o in objects]
119
+
120
+ # Expiry times per object (None becomes -1 for no expiry)
121
+ self.object_expiry_time = jnp.array(
122
+ [o.expiry_time if o.expiry_time is not None else -1 for o in objects]
123
+ )
117
124
 
118
125
  # Compute reward steps per object (using max_reward_delay attribute)
119
126
  object_max_reward_delay = jnp.array([o.max_reward_delay for o in objects])
@@ -193,6 +200,47 @@ class ForagaxEnv(environment.Environment):
193
200
  max_steps_in_episode=None,
194
201
  )
195
202
 
203
+ def _place_timer_at_position(
204
+ self, grid: jax.Array, y: int, x: int, timer_val: int
205
+ ) -> jax.Array:
206
+ """Place a timer at a specific position."""
207
+ return grid.at[y, x].set(timer_val)
208
+
209
+ def _place_timer_at_random_position(
210
+ self,
211
+ grid: jax.Array,
212
+ y: int,
213
+ x: int,
214
+ timer_val: int,
215
+ biome_grid: jax.Array,
216
+ rand_key: jax.Array,
217
+ ) -> jax.Array:
218
+ """Place a timer at a random valid position within the same biome."""
219
+ # Set the original position to empty temporarily
220
+ grid_temp = grid.at[y, x].set(0)
221
+
222
+ # Find all valid spawn locations (empty cells within the same biome)
223
+ biome_id = biome_grid[y, x]
224
+ biome_mask = biome_grid == biome_id
225
+ empty_mask = grid_temp == 0
226
+ valid_spawn_mask = biome_mask & empty_mask
227
+
228
+ num_valid_spawns = jnp.sum(valid_spawn_mask)
229
+
230
+ # Get indices of valid spawn locations, padded to a static size
231
+ y_indices, x_indices = jnp.nonzero(
232
+ valid_spawn_mask, size=self.size[0] * self.size[1], fill_value=-1
233
+ )
234
+ valid_spawn_indices = jnp.stack([y_indices, x_indices], axis=1)
235
+
236
+ # Select a random valid location
237
+ random_idx = jax.random.randint(rand_key, (), 0, num_valid_spawns)
238
+ new_spawn_pos = valid_spawn_indices[random_idx]
239
+
240
+ # Place the timer at the new random position
241
+ new_grid = grid_temp.at[new_spawn_pos[0], new_spawn_pos[1]].set(timer_val)
242
+ return new_grid
243
+
196
244
  def step_env(
197
245
  self,
198
246
  key: jax.Array,
@@ -277,54 +325,129 @@ class ForagaxEnv(environment.Environment):
277
325
  is_timer, state.object_grid + num_obj_types, state.object_grid
278
326
  )
279
327
 
328
+ # Track which objects just respawned (timer -> object transition)
329
+ # An object respawned if it was negative and is now positive
330
+ just_respawned = is_timer & (object_grid > 0)
331
+
332
+ # Update spawn times for objects that just respawned
333
+ object_spawn_time_grid = jnp.where(
334
+ just_respawned, state.time, state.object_spawn_time_grid
335
+ )
336
+
280
337
  # Collect object: set a timer
281
338
  regen_delay = jax.lax.switch(
282
339
  obj_at_pos, self.regen_delay_fns, state.time, regen_subkey
283
340
  )
284
341
  encoded_timer = obj_at_pos - ((regen_delay + 1) * num_obj_types)
285
342
 
286
- def place_at_current_pos(current_grid, timer_val):
287
- return current_grid.at[pos[1], pos[0]].set(timer_val)
288
-
289
- def place_at_random_pos(current_grid, timer_val):
290
- # Set the collected position to empty temporarily
291
- grid = current_grid.at[pos[1], pos[0]].set(0)
292
-
293
- # Find all valid spawn locations (empty cells within the same biome)
294
- biome_id = state.biome_grid[pos[1], pos[0]]
295
- biome_mask = state.biome_grid == biome_id
296
- empty_mask = grid == 0
297
- valid_spawn_mask = biome_mask & empty_mask
298
-
299
- num_valid_spawns = jnp.sum(valid_spawn_mask)
300
-
301
- # Get indices of valid spawn locations, padded to a static size
302
- y_indices, x_indices = jnp.nonzero(
303
- valid_spawn_mask, size=self.size[0] * self.size[1], fill_value=-1
304
- )
305
- valid_spawn_indices = jnp.stack([y_indices, x_indices], axis=1)
306
-
307
- # Select a random valid location
308
- random_idx = jax.random.randint(rand_key, (), 0, num_valid_spawns)
309
- new_spawn_pos = valid_spawn_indices[random_idx]
310
-
311
- # Place the timer at the new random position
312
- return grid.at[new_spawn_pos[0], new_spawn_pos[1]].set(timer_val)
313
-
314
343
  # If collected, replace object with timer; otherwise, keep it
315
344
  val_at_pos = object_grid[pos[1], pos[0]]
316
345
  should_collect = is_collectable & (val_at_pos > 0)
317
346
 
318
347
  # When not collecting, the value at the position remains unchanged.
319
348
  # When collecting, we either place the timer at the current position or a random one.
349
+ def do_collection():
350
+ return jax.lax.cond(
351
+ self.object_random_respawn[obj_at_pos],
352
+ lambda: self._place_timer_at_random_position(
353
+ object_grid,
354
+ pos[1],
355
+ pos[0],
356
+ encoded_timer,
357
+ state.biome_grid,
358
+ rand_key,
359
+ ),
360
+ lambda: self._place_timer_at_position(
361
+ object_grid, pos[1], pos[0], encoded_timer
362
+ ),
363
+ )
364
+
365
+ def no_collection():
366
+ return object_grid
367
+
320
368
  object_grid = jax.lax.cond(
321
369
  should_collect,
322
- lambda: jax.lax.cond(
323
- self.object_random_respawn[obj_at_pos],
324
- lambda: place_at_random_pos(object_grid, encoded_timer),
325
- lambda: place_at_current_pos(object_grid, encoded_timer),
326
- ),
327
- lambda: object_grid,
370
+ do_collection,
371
+ no_collection,
372
+ )
373
+
374
+ # 3.5. HANDLE OBJECT EXPIRY
375
+ # Check each cell for objects that have exceeded their expiry time
376
+ current_objects_for_expiry = jnp.maximum(0, object_grid)
377
+
378
+ # Calculate age of each object (current_time - spawn_time)
379
+ object_ages = state.time - object_spawn_time_grid
380
+
381
+ # Get expiry time for each object type in the grid
382
+ expiry_times = self.object_expiry_time[current_objects_for_expiry]
383
+
384
+ # Check if object should expire (age >= expiry_time and expiry_time >= 0)
385
+ should_expire = (
386
+ (object_ages >= expiry_times)
387
+ & (expiry_times >= 0)
388
+ & (current_objects_for_expiry > 0)
389
+ )
390
+
391
+ # For expired objects, calculate expiry regen delay
392
+ key, expiry_key = jax.random.split(key)
393
+
394
+ # Process expiry for all positions that need it
395
+ def process_expiry(y, x, grid, spawn_grid, key):
396
+ obj_id = current_objects_for_expiry[y, x]
397
+ should_exp = should_expire[y, x]
398
+
399
+ def expire_object():
400
+ # Get expiry regen delay for this object
401
+ exp_key = jax.random.fold_in(key, y * self.size[0] + x)
402
+ exp_delay = jax.lax.switch(
403
+ obj_id, self.expiry_regen_delay_fns, state.time, exp_key
404
+ )
405
+ encoded_exp_timer = obj_id - ((exp_delay + 1) * num_obj_types)
406
+
407
+ # Check if this object should respawn randomly
408
+ should_random_respawn = self.object_random_respawn[obj_id]
409
+
410
+ # Use second split for randomness in random placement
411
+ rand_key = jax.random.split(exp_key)[1]
412
+
413
+ # Place timer either at current position or random position
414
+ new_grid = jax.lax.cond(
415
+ should_random_respawn,
416
+ lambda: self._place_timer_at_random_position(
417
+ grid, y, x, encoded_exp_timer, state.biome_grid, rand_key
418
+ ),
419
+ lambda: self._place_timer_at_position(
420
+ grid, y, x, encoded_exp_timer
421
+ ),
422
+ )
423
+
424
+ return new_grid, spawn_grid
425
+
426
+ def no_expire():
427
+ return grid, spawn_grid
428
+
429
+ return jax.lax.cond(should_exp, expire_object, no_expire)
430
+
431
+ # Apply expiry to all cells (vectorized)
432
+ def scan_expiry_row(carry, y):
433
+ grid, spawn_grid, key = carry
434
+
435
+ def scan_expiry_col(carry_col, x):
436
+ grid_col, spawn_grid_col, key_col = carry_col
437
+ grid_col, spawn_grid_col = process_expiry(
438
+ y, x, grid_col, spawn_grid_col, key_col
439
+ )
440
+ return (grid_col, spawn_grid_col, key_col), None
441
+
442
+ (grid, spawn_grid, key), _ = jax.lax.scan(
443
+ scan_expiry_col, (grid, spawn_grid, key), jnp.arange(self.size[0])
444
+ )
445
+ return (grid, spawn_grid, key), None
446
+
447
+ (object_grid, object_spawn_time_grid, _), _ = jax.lax.scan(
448
+ scan_expiry_row,
449
+ (object_grid, object_spawn_time_grid, expiry_key),
450
+ jnp.arange(self.size[1]),
328
451
  )
329
452
 
330
453
  info = {"discount": self.discount(state, params)}
@@ -345,6 +468,7 @@ class ForagaxEnv(environment.Environment):
345
468
  biome_grid=state.biome_grid,
346
469
  time=state.time + 1,
347
470
  digestion_buffer=digestion_buffer,
471
+ object_spawn_time_grid=object_spawn_time_grid,
348
472
  )
349
473
 
350
474
  done = self.is_terminal(state, params)
@@ -378,12 +502,16 @@ class ForagaxEnv(environment.Environment):
378
502
  # Place agent in the center of the world
379
503
  agent_pos = jnp.array([self.size[0] // 2, self.size[1] // 2])
380
504
 
505
+ # Initialize spawn times to 0 (all objects spawn at time 0)
506
+ object_spawn_time_grid = jnp.zeros((self.size[1], self.size[0]), dtype=int)
507
+
381
508
  state = EnvState(
382
509
  pos=agent_pos,
383
510
  object_grid=object_grid,
384
511
  biome_grid=biome_grid,
385
512
  time=0,
386
513
  digestion_buffer=jnp.zeros((self.max_reward_delay,)),
514
+ object_spawn_time_grid=object_spawn_time_grid,
387
515
  )
388
516
 
389
517
  return self.get_obs(state, params), state
@@ -450,6 +578,12 @@ class ForagaxEnv(environment.Environment):
450
578
  (self.max_reward_delay,),
451
579
  float,
452
580
  ),
581
+ "object_spawn_time_grid": spaces.Box(
582
+ 0,
583
+ jnp.inf,
584
+ (self.size[1], self.size[0]),
585
+ int,
586
+ ),
453
587
  }
454
588
  )
455
589
 
foragax/objects.py CHANGED
@@ -18,6 +18,7 @@ class BaseForagaxObject:
18
18
  color: Tuple[int, int, int] = (0, 0, 0),
19
19
  random_respawn: bool = False,
20
20
  max_reward_delay: int = 0,
21
+ expiry_time: Optional[int] = None,
21
22
  ):
22
23
  self.name = name
23
24
  self.blocking = blocking
@@ -25,6 +26,7 @@ class BaseForagaxObject:
25
26
  self.color = color
26
27
  self.random_respawn = random_respawn
27
28
  self.max_reward_delay = max_reward_delay
29
+ self.expiry_time = expiry_time
28
30
 
29
31
  @abc.abstractmethod
30
32
  def reward(self, clock: int, rng: jax.Array) -> float:
@@ -36,6 +38,11 @@ class BaseForagaxObject:
36
38
  """Reward delay function."""
37
39
  raise NotImplementedError
38
40
 
41
+ @abc.abstractmethod
42
+ def expiry_regen_delay(self, clock: int, rng: jax.Array) -> int:
43
+ """Expiry regeneration delay function."""
44
+ raise NotImplementedError
45
+
39
46
 
40
47
  class DefaultForagaxObject(BaseForagaxObject):
41
48
  """Base class for default objects in the Foragax environment."""
@@ -51,15 +58,24 @@ class DefaultForagaxObject(BaseForagaxObject):
51
58
  random_respawn: bool = False,
52
59
  reward_delay: int = 0,
53
60
  max_reward_delay: Optional[int] = None,
61
+ expiry_time: Optional[int] = None,
62
+ expiry_regen_delay: Tuple[int, int] = (10, 100),
54
63
  ):
55
64
  if max_reward_delay is None:
56
65
  max_reward_delay = reward_delay
57
66
  super().__init__(
58
- name, blocking, collectable, color, random_respawn, max_reward_delay
67
+ name,
68
+ blocking,
69
+ collectable,
70
+ color,
71
+ random_respawn,
72
+ max_reward_delay,
73
+ expiry_time,
59
74
  )
60
75
  self.reward_val = reward
61
76
  self.regen_delay_range = regen_delay
62
77
  self.reward_delay_val = reward_delay
78
+ self.expiry_regen_delay_range = expiry_regen_delay
63
79
 
64
80
  def reward(self, clock: int, rng: jax.Array) -> float:
65
81
  """Default reward function."""
@@ -74,6 +90,11 @@ class DefaultForagaxObject(BaseForagaxObject):
74
90
  """Default reward delay function."""
75
91
  return self.reward_delay_val
76
92
 
93
+ def expiry_regen_delay(self, clock: int, rng: jax.Array) -> int:
94
+ """Default expiry regeneration delay function."""
95
+ min_delay, max_delay = self.expiry_regen_delay_range
96
+ return jax.random.randint(rng, (), min_delay, max_delay)
97
+
77
98
 
78
99
  class NormalRegenForagaxObject(DefaultForagaxObject):
79
100
  """Object with regeneration delay from a normal distribution."""
@@ -89,7 +110,16 @@ class NormalRegenForagaxObject(DefaultForagaxObject):
89
110
  random_respawn: bool = False,
90
111
  reward_delay: int = 0,
91
112
  max_reward_delay: Optional[int] = None,
113
+ expiry_time: Optional[int] = None,
114
+ mean_expiry_regen_delay: Optional[int] = None,
115
+ std_expiry_regen_delay: Optional[int] = None,
92
116
  ):
117
+ # If expiry regen delays not provided, use same as normal regen
118
+ if mean_expiry_regen_delay is None:
119
+ mean_expiry_regen_delay = mean_regen_delay
120
+ if std_expiry_regen_delay is None:
121
+ std_expiry_regen_delay = std_regen_delay
122
+
93
123
  super().__init__(
94
124
  name=name,
95
125
  reward=reward,
@@ -99,15 +129,27 @@ class NormalRegenForagaxObject(DefaultForagaxObject):
99
129
  random_respawn=random_respawn,
100
130
  reward_delay=reward_delay,
101
131
  max_reward_delay=max_reward_delay,
132
+ expiry_time=expiry_time,
133
+ expiry_regen_delay=(mean_expiry_regen_delay, mean_expiry_regen_delay),
102
134
  )
103
135
  self.mean_regen_delay = mean_regen_delay
104
136
  self.std_regen_delay = std_regen_delay
137
+ self.mean_expiry_regen_delay = mean_expiry_regen_delay
138
+ self.std_expiry_regen_delay = std_expiry_regen_delay
105
139
 
106
140
  def regen_delay(self, clock: int, rng: jax.Array) -> int:
107
141
  """Regeneration delay from a normal distribution."""
108
142
  delay = self.mean_regen_delay + jax.random.normal(rng) * self.std_regen_delay
109
143
  return jnp.maximum(0, delay).astype(int)
110
144
 
145
+ def expiry_regen_delay(self, clock: int, rng: jax.Array) -> int:
146
+ """Expiry regeneration delay from a normal distribution."""
147
+ delay = (
148
+ self.mean_expiry_regen_delay
149
+ + jax.random.normal(rng) * self.std_expiry_regen_delay
150
+ )
151
+ return jnp.maximum(0, delay).astype(int)
152
+
111
153
 
112
154
  class WeatherObject(NormalRegenForagaxObject):
113
155
  """Object with reward based on temperature data."""
@@ -124,6 +166,9 @@ class WeatherObject(NormalRegenForagaxObject):
124
166
  random_respawn: bool = False,
125
167
  reward_delay: int = 0,
126
168
  max_reward_delay: Optional[int] = None,
169
+ expiry_time: Optional[int] = None,
170
+ mean_expiry_regen_delay: Optional[int] = None,
171
+ std_expiry_regen_delay: Optional[int] = None,
127
172
  ):
128
173
  super().__init__(
129
174
  name=name,
@@ -134,6 +179,9 @@ class WeatherObject(NormalRegenForagaxObject):
134
179
  random_respawn=random_respawn,
135
180
  reward_delay=reward_delay,
136
181
  max_reward_delay=max_reward_delay,
182
+ expiry_time=expiry_time,
183
+ mean_expiry_regen_delay=mean_expiry_regen_delay,
184
+ std_expiry_regen_delay=std_expiry_regen_delay,
137
185
  )
138
186
  self.rewards = rewards * multiplier
139
187
  self.repeat = repeat
@@ -332,6 +380,44 @@ GREEN_FAKE_UNIFORM_RANDOM = DefaultForagaxObject(
332
380
  random_respawn=True,
333
381
  )
334
382
 
383
+ # Random respawn variants with expiry
384
+ BROWN_MOREL_UNIFORM_RANDOM_EXPIRY = DefaultForagaxObject(
385
+ name="brown_morel",
386
+ reward=10.0,
387
+ collectable=True,
388
+ color=(63, 30, 25),
389
+ regen_delay=(90, 110),
390
+ random_respawn=True,
391
+ expiry_time=500,
392
+ )
393
+ BROWN_OYSTER_UNIFORM_RANDOM_EXPIRY = DefaultForagaxObject(
394
+ name="brown_oyster",
395
+ reward=1.0,
396
+ collectable=True,
397
+ color=(63, 30, 25),
398
+ regen_delay=(9, 11),
399
+ random_respawn=True,
400
+ expiry_time=500,
401
+ )
402
+ GREEN_DEATHCAP_UNIFORM_RANDOM_EXPIRY = DefaultForagaxObject(
403
+ name="green_deathcap",
404
+ reward=-5.0,
405
+ collectable=True,
406
+ color=(0, 255, 0),
407
+ regen_delay=(9, 11),
408
+ random_respawn=True,
409
+ expiry_time=500,
410
+ )
411
+ GREEN_FAKE_UNIFORM_RANDOM_EXPIRY = DefaultForagaxObject(
412
+ name="green_fake",
413
+ reward=0.0,
414
+ collectable=True,
415
+ color=(0, 255, 0),
416
+ regen_delay=(9, 11),
417
+ random_respawn=True,
418
+ expiry_time=500,
419
+ )
420
+
335
421
 
336
422
  def create_weather_objects(
337
423
  file_index: int = 0,
@@ -340,6 +426,9 @@ def create_weather_objects(
340
426
  same_color: bool = False,
341
427
  random_respawn: bool = False,
342
428
  reward_delay: int = 0,
429
+ expiry_time: Optional[int] = None,
430
+ mean_expiry_regen_delay: Optional[int] = None,
431
+ std_expiry_regen_delay: Optional[int] = None,
343
432
  ):
344
433
  """Create HOT and COLD WeatherObject instances using the specified file.
345
434
 
@@ -348,6 +437,11 @@ def create_weather_objects(
348
437
  repeat: How many steps each temperature value repeats for.
349
438
  multiplier: Base multiplier applied to HOT; COLD will use -multiplier.
350
439
  same_color: If True, both HOT and COLD use the same color.
440
+ random_respawn: If True, objects respawn at random locations.
441
+ reward_delay: Number of steps before reward is delivered.
442
+ expiry_time: Time steps before object expires (None = no expiry).
443
+ mean_expiry_regen_delay: Mean delay for expiry respawn.
444
+ std_expiry_regen_delay: Standard deviation for expiry respawn delay.
351
445
 
352
446
  Returns:
353
447
  A tuple (HOT, COLD) of WeatherObject instances.
@@ -370,6 +464,9 @@ def create_weather_objects(
370
464
  color=hot_color,
371
465
  random_respawn=random_respawn,
372
466
  reward_delay=reward_delay,
467
+ expiry_time=expiry_time,
468
+ mean_expiry_regen_delay=mean_expiry_regen_delay,
469
+ std_expiry_regen_delay=std_expiry_regen_delay,
373
470
  )
374
471
 
375
472
  cold_color = hot_color if same_color else (0, 255, 255)
@@ -381,6 +478,9 @@ def create_weather_objects(
381
478
  color=cold_color,
382
479
  random_respawn=random_respawn,
383
480
  reward_delay=reward_delay,
481
+ expiry_time=expiry_time,
482
+ mean_expiry_regen_delay=mean_expiry_regen_delay,
483
+ std_expiry_regen_delay=std_expiry_regen_delay,
384
484
  )
385
485
 
386
486
  return hot, cold
foragax/registry.py CHANGED
@@ -12,18 +12,22 @@ from foragax.objects import (
12
12
  BROWN_MOREL_2,
13
13
  BROWN_MOREL_UNIFORM,
14
14
  BROWN_MOREL_UNIFORM_RANDOM,
15
+ BROWN_MOREL_UNIFORM_RANDOM_EXPIRY,
15
16
  BROWN_OYSTER,
16
17
  BROWN_OYSTER_UNIFORM,
17
18
  BROWN_OYSTER_UNIFORM_RANDOM,
19
+ BROWN_OYSTER_UNIFORM_RANDOM_EXPIRY,
18
20
  GREEN_DEATHCAP,
19
21
  GREEN_DEATHCAP_2,
20
22
  GREEN_DEATHCAP_3,
21
23
  GREEN_DEATHCAP_UNIFORM,
22
24
  GREEN_DEATHCAP_UNIFORM_RANDOM,
25
+ GREEN_DEATHCAP_UNIFORM_RANDOM_EXPIRY,
23
26
  GREEN_FAKE,
24
27
  GREEN_FAKE_2,
25
28
  GREEN_FAKE_UNIFORM,
26
29
  GREEN_FAKE_UNIFORM_RANDOM,
30
+ GREEN_FAKE_UNIFORM_RANDOM_EXPIRY,
27
31
  LARGE_MOREL,
28
32
  LARGE_OYSTER,
29
33
  MEDIUM_MOREL,
@@ -304,6 +308,26 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
304
308
  "nowrap": True,
305
309
  "deterministic_spawn": True,
306
310
  },
311
+ "ForagaxTwoBiome-v17": {
312
+ "size": (15, 15),
313
+ "aperture_size": None,
314
+ "objects": (
315
+ BROWN_MOREL_UNIFORM_RANDOM_EXPIRY,
316
+ BROWN_OYSTER_UNIFORM_RANDOM_EXPIRY,
317
+ GREEN_DEATHCAP_UNIFORM_RANDOM_EXPIRY,
318
+ GREEN_FAKE_UNIFORM_RANDOM_EXPIRY,
319
+ ),
320
+ "biomes": (
321
+ # Morel biome
322
+ Biome(start=(3, 0), stop=(5, 15), object_frequencies=(0.25, 0.0, 0.5, 0.0)),
323
+ # Oyster biome
324
+ Biome(
325
+ start=(10, 0), stop=(12, 15), object_frequencies=(0.0, 0.25, 0.0, 0.5)
326
+ ),
327
+ ),
328
+ "nowrap": False,
329
+ "deterministic_spawn": True,
330
+ },
307
331
  "ForagaxTwoBiomeSmall-v1": {
308
332
  "size": (16, 8),
309
333
  "aperture_size": None,