plancraft 0.3.29__py3-none-any.whl → 0.3.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plancraft/evaluator.py +100 -119
- plancraft/models/act.py +0 -3
- plancraft/models/base.py +0 -7
- plancraft/models/dummy.py +0 -6
- plancraft/models/oracle.py +0 -6
- {plancraft-0.3.29.dist-info → plancraft-0.3.31.dist-info}/METADATA +1 -1
- {plancraft-0.3.29.dist-info → plancraft-0.3.31.dist-info}/RECORD +9 -9
- {plancraft-0.3.29.dist-info → plancraft-0.3.31.dist-info}/WHEEL +0 -0
- {plancraft-0.3.29.dist-info → plancraft-0.3.31.dist-info}/licenses/LICENSE +0 -0
plancraft/evaluator.py
CHANGED
@@ -2,6 +2,7 @@ import json
|
|
2
2
|
import os
|
3
3
|
from typing import Optional
|
4
4
|
from copy import deepcopy
|
5
|
+
from collections import deque
|
5
6
|
|
6
7
|
import imageio
|
7
8
|
from tqdm import tqdm
|
@@ -244,16 +245,6 @@ class Evaluator:
|
|
244
245
|
# check if the episode is done
|
245
246
|
success = self.check_done(observation["inventory"], example.target)
|
246
247
|
|
247
|
-
# update model with success or failure
|
248
|
-
# observation is the next state after the action (s1)
|
249
|
-
# history is the dialogue history
|
250
|
-
# -- the last message contains the action taken (a0)
|
251
|
-
# -- the second to last message is the observation (s0)
|
252
|
-
# success is whether the episode is successful (r)
|
253
|
-
model.update(
|
254
|
-
observation=observation, history=history, success=success, action=action
|
255
|
-
)
|
256
|
-
|
257
248
|
# exit if success
|
258
249
|
if success or isinstance(action, StopAction):
|
259
250
|
break
|
@@ -273,161 +264,151 @@ class Evaluator:
|
|
273
264
|
self,
|
274
265
|
examples: list[PlancraftExample],
|
275
266
|
model,
|
267
|
+
batch_size: int = 4,
|
268
|
+
callback_fn: Optional[callable] = None,
|
276
269
|
) -> list:
|
277
270
|
"""
|
278
|
-
|
271
|
+
Processes examples in batches with dynamic replacement from a queue.
|
279
272
|
|
280
|
-
|
281
|
-
|
273
|
+
Args:
|
274
|
+
examples: List of examples to process
|
275
|
+
model: Model to use for evaluation
|
276
|
+
batch_size: Maximum number of concurrent environments
|
277
|
+
callback_fn: Optional callback function to call after each result
|
282
278
|
"""
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
279
|
+
pending_examples = deque(examples)
|
280
|
+
active_examples = []
|
281
|
+
active_environments = []
|
282
|
+
active_histories = []
|
283
|
+
active_observations = []
|
284
|
+
results = {ex.id: None for ex in examples}
|
285
|
+
|
286
|
+
# Initialize first batch
|
287
|
+
while len(active_examples) < batch_size and pending_examples:
|
288
|
+
example = pending_examples.popleft()
|
289
|
+
env = PlancraftEnvironment(
|
290
|
+
inventory=deepcopy(example.slotted_inventory),
|
288
291
|
resolution=self.resolution,
|
289
292
|
)
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
# Track which environments are still active
|
295
|
-
active_mask = [True for _ in range(len(examples))]
|
296
|
-
results = [None for _ in range(len(examples))]
|
297
|
-
observations = []
|
298
|
-
|
299
|
-
# Initialize observations (s0) and user messages from environment
|
300
|
-
for i in range(len(examples)):
|
301
|
-
obs = environments[i].step()
|
302
|
-
obs["target"] = examples[i].target
|
293
|
+
history = self.create_history()
|
294
|
+
obs = env.step()
|
295
|
+
obs["target"] = example.target
|
303
296
|
obs["message"] = self.convert_observation_to_message(obs, model=model)
|
304
|
-
observations.append(obs)
|
305
|
-
|
306
|
-
# Process until all done or max steps reached
|
307
|
-
while any(active_mask) and all(
|
308
|
-
history.num_steps < self.max_steps for history in histories
|
309
|
-
):
|
310
|
-
# Gather active environments
|
311
|
-
active_indices = [
|
312
|
-
i
|
313
|
-
for i, active in enumerate(active_mask)
|
314
|
-
if active and histories[i].num_steps < self.max_steps
|
315
|
-
]
|
316
|
-
if not active_indices:
|
317
|
-
break
|
318
|
-
|
319
|
-
# For each active environment, add new obs to history for next iteration
|
320
|
-
for env_idx in active_indices:
|
321
|
-
if active_mask[env_idx]:
|
322
|
-
histories[env_idx].add_observation_to_history(observations[env_idx])
|
323
|
-
histories[env_idx].add_message_to_history(
|
324
|
-
content=observations[env_idx]["message"], role="user"
|
325
|
-
)
|
326
297
|
|
327
|
-
|
328
|
-
|
298
|
+
active_examples.append(example)
|
299
|
+
active_environments.append(env)
|
300
|
+
active_histories.append(history)
|
301
|
+
active_observations.append(obs)
|
302
|
+
|
303
|
+
# Process until all examples are done
|
304
|
+
while active_examples:
|
305
|
+
# Add observations to histories
|
306
|
+
for i in range(len(active_examples)):
|
307
|
+
active_histories[i].add_observation_to_history(active_observations[i])
|
308
|
+
active_histories[i].add_message_to_history(
|
309
|
+
content=active_observations[i]["message"], role="user"
|
310
|
+
)
|
329
311
|
|
330
|
-
#
|
312
|
+
# Get model predictions for current batch
|
331
313
|
raw_actions = model.batch_step(
|
332
|
-
|
314
|
+
active_observations, dialogue_histories=active_histories
|
333
315
|
)
|
334
316
|
|
335
|
-
# Process each
|
317
|
+
# Process each active environment
|
318
|
+
completed_indices = []
|
336
319
|
successes = []
|
337
320
|
actions = []
|
338
|
-
|
339
|
-
|
321
|
+
|
322
|
+
for i, (example, raw_action) in enumerate(
|
323
|
+
zip(active_examples, raw_actions)
|
324
|
+
):
|
325
|
+
# Handle model output
|
340
326
|
if isinstance(raw_action, PlancraftModelOutput):
|
341
|
-
|
327
|
+
active_histories[i].add_message_to_history(
|
342
328
|
content=raw_action.action,
|
343
329
|
role="assistant",
|
344
330
|
**(raw_action.kwargs or {}),
|
345
331
|
)
|
346
332
|
raw_action = raw_action.action
|
347
|
-
elif isinstance(raw_action, str):
|
348
|
-
histories[env_idx].add_message_to_history(
|
349
|
-
content=raw_action, role="assistant"
|
350
|
-
)
|
351
333
|
else:
|
352
|
-
|
353
|
-
|
334
|
+
active_histories[i].add_message_to_history(
|
335
|
+
content=raw_action, role="assistant"
|
354
336
|
)
|
355
337
|
|
356
|
-
# Parse action
|
338
|
+
# Parse and execute action
|
357
339
|
action = self.parse_raw_model_response(
|
358
340
|
raw_action,
|
359
|
-
observation=
|
360
|
-
history=
|
341
|
+
observation=active_observations[i],
|
342
|
+
history=active_histories[i],
|
361
343
|
)
|
362
344
|
actions.append(action)
|
363
345
|
success = False
|
364
|
-
|
346
|
+
|
365
347
|
if isinstance(action, StopAction):
|
366
|
-
|
367
|
-
|
368
|
-
success = examples[env_idx].impossible
|
369
|
-
observations[env_idx] = None
|
370
|
-
# If parsed action is a string, it's a message
|
348
|
+
success = example.impossible
|
349
|
+
active_observations[i] = None
|
371
350
|
elif isinstance(action, str):
|
372
|
-
obs =
|
373
|
-
obs["target"] =
|
351
|
+
obs = active_environments[i].step()
|
352
|
+
obs["target"] = example.target
|
374
353
|
obs["message"] = action
|
375
|
-
|
376
|
-
# Otherwise it's an actual environment action
|
354
|
+
active_observations[i] = obs
|
377
355
|
else:
|
378
|
-
obs =
|
379
|
-
obs["target"] =
|
356
|
+
obs = active_environments[i].step(action)
|
357
|
+
obs["target"] = example.target
|
380
358
|
obs["message"] = self.convert_observation_to_message(
|
381
359
|
obs, model=model
|
382
360
|
)
|
383
|
-
|
384
|
-
success = self.check_done(
|
385
|
-
obs["inventory"], examples[env_idx].target
|
386
|
-
)
|
361
|
+
active_observations[i] = obs
|
362
|
+
success = self.check_done(obs["inventory"], example.target)
|
387
363
|
|
388
364
|
successes.append(success)
|
389
365
|
|
390
|
-
#
|
366
|
+
# Check if environment is done
|
391
367
|
if (
|
392
368
|
success
|
393
369
|
or isinstance(action, StopAction)
|
394
|
-
or
|
370
|
+
or active_histories[i].num_steps >= self.max_steps
|
395
371
|
):
|
396
|
-
|
397
|
-
results[env_idx] = {
|
372
|
+
results[example.id] = {
|
398
373
|
"success": success,
|
399
|
-
"recipe_type":
|
400
|
-
"complexity":
|
401
|
-
"number_of_steps":
|
402
|
-
"model_trace":
|
403
|
-
"example_id":
|
404
|
-
"images":
|
374
|
+
"recipe_type": example.recipe_type,
|
375
|
+
"complexity": example.complexity_split,
|
376
|
+
"number_of_steps": active_histories[i].num_steps,
|
377
|
+
"model_trace": active_histories[i].trace(),
|
378
|
+
"example_id": example.id,
|
379
|
+
"images": active_histories[i].images,
|
405
380
|
}
|
381
|
+
completed_indices.append(i)
|
382
|
+
if callback_fn:
|
383
|
+
callback_fn(results[example.id])
|
384
|
+
|
385
|
+
# Remove completed environments and replace with new ones
|
386
|
+
for i in reversed(completed_indices):
|
387
|
+
active_examples.pop(i)
|
388
|
+
active_environments.pop(i)
|
389
|
+
active_histories.pop(i)
|
390
|
+
active_observations.pop(i)
|
391
|
+
|
392
|
+
# Add new environment if there are pending examples
|
393
|
+
if pending_examples:
|
394
|
+
example = pending_examples.popleft()
|
395
|
+
env = PlancraftEnvironment(
|
396
|
+
inventory=deepcopy(example.slotted_inventory),
|
397
|
+
resolution=self.resolution,
|
398
|
+
)
|
399
|
+
history = self.create_history()
|
400
|
+
obs = env.step()
|
401
|
+
obs["target"] = example.target
|
402
|
+
obs["message"] = self.convert_observation_to_message(
|
403
|
+
obs, model=model
|
404
|
+
)
|
406
405
|
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
observations=batch_observations,
|
412
|
-
histories=batch_histories,
|
413
|
-
successes=successes,
|
414
|
-
actions=actions,
|
415
|
-
)
|
416
|
-
|
417
|
-
# Fill in results for any environment that never completed
|
418
|
-
for i, result in enumerate(results):
|
419
|
-
if result is None:
|
420
|
-
results[i] = {
|
421
|
-
"success": False,
|
422
|
-
"recipe_type": examples[i].recipe_type,
|
423
|
-
"complexity": examples[i].complexity_split,
|
424
|
-
"number_of_steps": histories[i].num_steps,
|
425
|
-
"model_trace": histories[i].trace(),
|
426
|
-
"example_id": examples[i].id,
|
427
|
-
"images": histories[i].images,
|
428
|
-
}
|
406
|
+
active_examples.append(example)
|
407
|
+
active_environments.append(env)
|
408
|
+
active_histories.append(history)
|
409
|
+
active_observations.append(obs)
|
429
410
|
|
430
|
-
return results
|
411
|
+
return list(results.values())
|
431
412
|
|
432
413
|
def eval_all_examples(self, model, progress_bar=False) -> list:
|
433
414
|
results = []
|
plancraft/models/act.py
CHANGED
plancraft/models/base.py
CHANGED
@@ -33,10 +33,3 @@ class PlancraftBaseModel(abc.ABC):
|
|
33
33
|
Reset the model state - ready for a new episode
|
34
34
|
"""
|
35
35
|
raise NotImplementedError()
|
36
|
-
|
37
|
-
@abc.abstractmethod
|
38
|
-
def update(self, **kwargs) -> None:
|
39
|
-
"""
|
40
|
-
Update the model state based on the dialogue history
|
41
|
-
"""
|
42
|
-
raise NotImplementedError()
|
plancraft/models/dummy.py
CHANGED
@@ -45,9 +45,3 @@ class DummyModel(PlancraftBaseModel):
|
|
45
45
|
self, observations: list[dict], **kwargs
|
46
46
|
) -> list[PlancraftModelOutput]:
|
47
47
|
return [self.step(observation) for observation in observations]
|
48
|
-
|
49
|
-
def update(self, **kwargs):
|
50
|
-
pass
|
51
|
-
|
52
|
-
def batch_update(self, **kwargs):
|
53
|
-
pass
|
plancraft/models/oracle.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
plancraft/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
plancraft/config.py,sha256=ShsFRlJ7plsl3ToD9fiO_4LDQuXdbjNV6Xp6o3Yk2Yg,4315
|
3
|
-
plancraft/evaluator.py,sha256=
|
3
|
+
plancraft/evaluator.py,sha256=VteLAT_rPogw8NYZos7jEuuakyfE_3CsFuv6A39Geyw,17614
|
4
4
|
plancraft/generate_dataset.py,sha256=DlrU-PmvWqSNJD1g1-8Lpb8n3N-Ogw3rje1nrRzjGKs,2382
|
5
5
|
plancraft/utils.py,sha256=VhnxMihh6pRhNjQTK5HDc0FYWmF9_EcQyRP_a7fbIZA,7156
|
6
6
|
plancraft/data/test.json,sha256=4jWfYMAVuZCFmGB4iZJAjlh9_8jXECdaGp8xn7_tAM4,1317131
|
@@ -1912,15 +1912,15 @@ plancraft/environment/tags/wooden_stairs.json,sha256=GCr2_5UGPMYZECqQ_5NYSvbwuwt
|
|
1912
1912
|
plancraft/environment/tags/wooden_trapdoors.json,sha256=DbjfwoHJL8VuYWV61A1uDqW7LJsGlOP4eoxcGIQVYr4,303
|
1913
1913
|
plancraft/environment/tags/wool.json,sha256=Z59l4mdPztVZBFaglJ4mV9H2OnyCVzhqQRi2dduak78,496
|
1914
1914
|
plancraft/models/__init__.py,sha256=TBrarn93qt4IFJRNqtzOfaA8jGMPCgD7DFs-M84ipmk,510
|
1915
|
-
plancraft/models/act.py,sha256=
|
1916
|
-
plancraft/models/base.py,sha256=
|
1915
|
+
plancraft/models/act.py,sha256=6Xb8rylg3OngOraVFgduH_hQR62VcoyTeFntN4q3hsQ,2691
|
1916
|
+
plancraft/models/base.py,sha256=S8EdkqWpn8nE1WcrqDoA4Hx4p52qEttGxnqjIPWvl3Q,852
|
1917
1917
|
plancraft/models/bbox_model.py,sha256=3b1IEspoHiVUR6GOWjEbp4YoxRhGkzKt-eOiwaN8NXo,17091
|
1918
|
-
plancraft/models/dummy.py,sha256=
|
1918
|
+
plancraft/models/dummy.py,sha256=_NUTviv5ye6KGzODRt0Zykk8shsek0QBqWCeZW3ldSQ,1495
|
1919
1919
|
plancraft/models/generators.py,sha256=F76_iPiqxUjDIrQwF58tzM0bLM91OkZJ0sBqBuki5wY,13939
|
1920
|
-
plancraft/models/oracle.py,sha256=
|
1920
|
+
plancraft/models/oracle.py,sha256=f-0KWlBuHy6wcxmDsxM3MQ_QwfBstzfbA26mlk1MgLA,1657
|
1921
1921
|
plancraft/models/utils.py,sha256=E-sZohvolWgGbpHQKgAgkgIfUJoVnT5pMt6JP8xLHKg,4034
|
1922
1922
|
plancraft/train/dataset.py,sha256=oFqEd4LG9oEQ-71teh0Wf7-jJbtybT2ZibfM2bBdBkM,5474
|
1923
|
-
plancraft-0.3.
|
1924
|
-
plancraft-0.3.
|
1925
|
-
plancraft-0.3.
|
1926
|
-
plancraft-0.3.
|
1923
|
+
plancraft-0.3.31.dist-info/METADATA,sha256=gU6j3SQEGdXIeW1pab_Pz6hspDhl_g0vaPIkIXRScYo,11148
|
1924
|
+
plancraft-0.3.31.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
1925
|
+
plancraft-0.3.31.dist-info/licenses/LICENSE,sha256=YGR8ehDB4t-T-lOQKMfKNR-2zsOU7E3E5NA8t25HKE0,1070
|
1926
|
+
plancraft-0.3.31.dist-info/RECORD,,
|
File without changes
|
File without changes
|