plancraft 0.3.29__py3-none-any.whl → 0.3.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plancraft/evaluator.py +99 -104
- {plancraft-0.3.29.dist-info → plancraft-0.3.30.dist-info}/METADATA +1 -1
- {plancraft-0.3.29.dist-info → plancraft-0.3.30.dist-info}/RECORD +5 -5
- {plancraft-0.3.29.dist-info → plancraft-0.3.30.dist-info}/WHEEL +0 -0
- {plancraft-0.3.29.dist-info → plancraft-0.3.30.dist-info}/licenses/LICENSE +0 -0
plancraft/evaluator.py
CHANGED
@@ -2,6 +2,7 @@ import json
|
|
2
2
|
import os
|
3
3
|
from typing import Optional
|
4
4
|
from copy import deepcopy
|
5
|
+
from collections import deque
|
5
6
|
|
6
7
|
import imageio
|
7
8
|
from tqdm import tqdm
|
@@ -273,161 +274,155 @@ class Evaluator:
|
|
273
274
|
self,
|
274
275
|
examples: list[PlancraftExample],
|
275
276
|
model,
|
277
|
+
batch_size: int = 4,
|
276
278
|
) -> list:
|
277
279
|
"""
|
278
|
-
|
280
|
+
Processes examples in batches with dynamic replacement from a queue.
|
279
281
|
|
280
|
-
|
281
|
-
|
282
|
+
Args:
|
283
|
+
examples: List of examples to process
|
284
|
+
model: Model to use for evaluation
|
285
|
+
batch_size: Maximum number of concurrent environments
|
282
286
|
"""
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
287
|
+
pending_examples = deque(examples)
|
288
|
+
active_examples = []
|
289
|
+
active_environments = []
|
290
|
+
active_histories = []
|
291
|
+
active_observations = []
|
292
|
+
results = {ex.id: None for ex in examples}
|
293
|
+
|
294
|
+
# Initialize first batch
|
295
|
+
while len(active_examples) < batch_size and pending_examples:
|
296
|
+
example = pending_examples.popleft()
|
297
|
+
env = PlancraftEnvironment(
|
298
|
+
inventory=deepcopy(example.slotted_inventory),
|
288
299
|
resolution=self.resolution,
|
289
300
|
)
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
# Track which environments are still active
|
295
|
-
active_mask = [True for _ in range(len(examples))]
|
296
|
-
results = [None for _ in range(len(examples))]
|
297
|
-
observations = []
|
298
|
-
|
299
|
-
# Initialize observations (s0) and user messages from environment
|
300
|
-
for i in range(len(examples)):
|
301
|
-
obs = environments[i].step()
|
302
|
-
obs["target"] = examples[i].target
|
301
|
+
history = self.create_history()
|
302
|
+
obs = env.step()
|
303
|
+
obs["target"] = example.target
|
303
304
|
obs["message"] = self.convert_observation_to_message(obs, model=model)
|
304
|
-
observations.append(obs)
|
305
|
-
|
306
|
-
# Process until all done or max steps reached
|
307
|
-
while any(active_mask) and all(
|
308
|
-
history.num_steps < self.max_steps for history in histories
|
309
|
-
):
|
310
|
-
# Gather active environments
|
311
|
-
active_indices = [
|
312
|
-
i
|
313
|
-
for i, active in enumerate(active_mask)
|
314
|
-
if active and histories[i].num_steps < self.max_steps
|
315
|
-
]
|
316
|
-
if not active_indices:
|
317
|
-
break
|
318
305
|
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
306
|
+
active_examples.append(example)
|
307
|
+
active_environments.append(env)
|
308
|
+
active_histories.append(history)
|
309
|
+
active_observations.append(obs)
|
310
|
+
|
311
|
+
# Process until all examples are done
|
312
|
+
while active_examples:
|
313
|
+
# Add observations to histories
|
314
|
+
for i in range(len(active_examples)):
|
315
|
+
active_histories[i].add_observation_to_history(active_observations[i])
|
316
|
+
active_histories[i].add_message_to_history(
|
317
|
+
content=active_observations[i]["message"], role="user"
|
318
|
+
)
|
329
319
|
|
330
|
-
#
|
320
|
+
# Get model predictions for current batch
|
331
321
|
raw_actions = model.batch_step(
|
332
|
-
|
322
|
+
active_observations, dialogue_histories=active_histories
|
333
323
|
)
|
334
324
|
|
335
|
-
# Process each
|
325
|
+
# Process each active environment
|
326
|
+
completed_indices = []
|
336
327
|
successes = []
|
337
328
|
actions = []
|
338
|
-
|
339
|
-
|
329
|
+
|
330
|
+
for i, (example, raw_action) in enumerate(
|
331
|
+
zip(active_examples, raw_actions)
|
332
|
+
):
|
333
|
+
# Handle model output
|
340
334
|
if isinstance(raw_action, PlancraftModelOutput):
|
341
|
-
|
335
|
+
active_histories[i].add_message_to_history(
|
342
336
|
content=raw_action.action,
|
343
337
|
role="assistant",
|
344
338
|
**(raw_action.kwargs or {}),
|
345
339
|
)
|
346
340
|
raw_action = raw_action.action
|
347
|
-
elif isinstance(raw_action, str):
|
348
|
-
histories[env_idx].add_message_to_history(
|
349
|
-
content=raw_action, role="assistant"
|
350
|
-
)
|
351
341
|
else:
|
352
|
-
|
353
|
-
|
342
|
+
active_histories[i].add_message_to_history(
|
343
|
+
content=raw_action, role="assistant"
|
354
344
|
)
|
355
345
|
|
356
|
-
# Parse action
|
346
|
+
# Parse and execute action
|
357
347
|
action = self.parse_raw_model_response(
|
358
348
|
raw_action,
|
359
|
-
observation=
|
360
|
-
history=
|
349
|
+
observation=active_observations[i],
|
350
|
+
history=active_histories[i],
|
361
351
|
)
|
362
352
|
actions.append(action)
|
363
353
|
success = False
|
364
|
-
|
354
|
+
|
365
355
|
if isinstance(action, StopAction):
|
366
|
-
|
367
|
-
|
368
|
-
success = examples[env_idx].impossible
|
369
|
-
observations[env_idx] = None
|
370
|
-
# If parsed action is a string, it's a message
|
356
|
+
success = example.impossible
|
357
|
+
active_observations[i] = None
|
371
358
|
elif isinstance(action, str):
|
372
|
-
obs =
|
373
|
-
obs["target"] =
|
359
|
+
obs = active_environments[i].step()
|
360
|
+
obs["target"] = example.target
|
374
361
|
obs["message"] = action
|
375
|
-
|
376
|
-
# Otherwise it's an actual environment action
|
362
|
+
active_observations[i] = obs
|
377
363
|
else:
|
378
|
-
obs =
|
379
|
-
obs["target"] =
|
364
|
+
obs = active_environments[i].step(action)
|
365
|
+
obs["target"] = example.target
|
380
366
|
obs["message"] = self.convert_observation_to_message(
|
381
367
|
obs, model=model
|
382
368
|
)
|
383
|
-
|
384
|
-
success = self.check_done(
|
385
|
-
obs["inventory"], examples[env_idx].target
|
386
|
-
)
|
369
|
+
active_observations[i] = obs
|
370
|
+
success = self.check_done(obs["inventory"], example.target)
|
387
371
|
|
388
372
|
successes.append(success)
|
389
373
|
|
390
|
-
#
|
374
|
+
# Check if environment is done
|
391
375
|
if (
|
392
376
|
success
|
393
377
|
or isinstance(action, StopAction)
|
394
|
-
or
|
378
|
+
or active_histories[i].num_steps >= self.max_steps
|
395
379
|
):
|
396
|
-
|
397
|
-
results[env_idx] = {
|
380
|
+
results[example.id] = {
|
398
381
|
"success": success,
|
399
|
-
"recipe_type":
|
400
|
-
"complexity":
|
401
|
-
"number_of_steps":
|
402
|
-
"model_trace":
|
403
|
-
"example_id":
|
404
|
-
"images":
|
382
|
+
"recipe_type": example.recipe_type,
|
383
|
+
"complexity": example.complexity_split,
|
384
|
+
"number_of_steps": active_histories[i].num_steps,
|
385
|
+
"model_trace": active_histories[i].trace(),
|
386
|
+
"example_id": example.id,
|
387
|
+
"images": active_histories[i].images,
|
405
388
|
}
|
389
|
+
completed_indices.append(i)
|
406
390
|
|
407
|
-
# Update
|
408
|
-
batch_observations = [observations[i] for i in active_indices]
|
409
|
-
batch_histories = [histories[i] for i in active_indices]
|
391
|
+
# Update model
|
410
392
|
model.batch_update(
|
411
|
-
observations=
|
412
|
-
histories=
|
393
|
+
observations=active_observations,
|
394
|
+
histories=active_histories,
|
413
395
|
successes=successes,
|
414
396
|
actions=actions,
|
415
397
|
)
|
416
398
|
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
399
|
+
# Remove completed environments and replace with new ones
|
400
|
+
for i in reversed(completed_indices):
|
401
|
+
active_examples.pop(i)
|
402
|
+
active_environments.pop(i)
|
403
|
+
active_histories.pop(i)
|
404
|
+
active_observations.pop(i)
|
405
|
+
|
406
|
+
# Add new environment if there are pending examples
|
407
|
+
if pending_examples:
|
408
|
+
example = pending_examples.popleft()
|
409
|
+
env = PlancraftEnvironment(
|
410
|
+
inventory=deepcopy(example.slotted_inventory),
|
411
|
+
resolution=self.resolution,
|
412
|
+
)
|
413
|
+
history = self.create_history()
|
414
|
+
obs = env.step()
|
415
|
+
obs["target"] = example.target
|
416
|
+
obs["message"] = self.convert_observation_to_message(
|
417
|
+
obs, model=model
|
418
|
+
)
|
419
|
+
|
420
|
+
active_examples.append(example)
|
421
|
+
active_environments.append(env)
|
422
|
+
active_histories.append(history)
|
423
|
+
active_observations.append(obs)
|
429
424
|
|
430
|
-
return results
|
425
|
+
return list(results.values())
|
431
426
|
|
432
427
|
def eval_all_examples(self, model, progress_bar=False) -> list:
|
433
428
|
results = []
|
@@ -1,6 +1,6 @@
|
|
1
1
|
plancraft/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
plancraft/config.py,sha256=ShsFRlJ7plsl3ToD9fiO_4LDQuXdbjNV6Xp6o3Yk2Yg,4315
|
3
|
-
plancraft/evaluator.py,sha256=
|
3
|
+
plancraft/evaluator.py,sha256=VFJnfitixU2Y4RTxp0lDALoCSFMMwMJPgSQC0Y0tmH8,18121
|
4
4
|
plancraft/generate_dataset.py,sha256=DlrU-PmvWqSNJD1g1-8Lpb8n3N-Ogw3rje1nrRzjGKs,2382
|
5
5
|
plancraft/utils.py,sha256=VhnxMihh6pRhNjQTK5HDc0FYWmF9_EcQyRP_a7fbIZA,7156
|
6
6
|
plancraft/data/test.json,sha256=4jWfYMAVuZCFmGB4iZJAjlh9_8jXECdaGp8xn7_tAM4,1317131
|
@@ -1920,7 +1920,7 @@ plancraft/models/generators.py,sha256=F76_iPiqxUjDIrQwF58tzM0bLM91OkZJ0sBqBuki5w
|
|
1920
1920
|
plancraft/models/oracle.py,sha256=jmt_kBBNXt0VWUX7q6OHkJoRZWItCMy4qGH5qbLSc1c,1755
|
1921
1921
|
plancraft/models/utils.py,sha256=E-sZohvolWgGbpHQKgAgkgIfUJoVnT5pMt6JP8xLHKg,4034
|
1922
1922
|
plancraft/train/dataset.py,sha256=oFqEd4LG9oEQ-71teh0Wf7-jJbtybT2ZibfM2bBdBkM,5474
|
1923
|
-
plancraft-0.3.
|
1924
|
-
plancraft-0.3.
|
1925
|
-
plancraft-0.3.
|
1926
|
-
plancraft-0.3.
|
1923
|
+
plancraft-0.3.30.dist-info/METADATA,sha256=tltUHYqXhfDXfsQGU5NLhEp6TjR41g6X0OWFn5dpttg,11148
|
1924
|
+
plancraft-0.3.30.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
1925
|
+
plancraft-0.3.30.dist-info/licenses/LICENSE,sha256=YGR8ehDB4t-T-lOQKMfKNR-2zsOU7E3E5NA8t25HKE0,1070
|
1926
|
+
plancraft-0.3.30.dist-info/RECORD,,
|
File without changes
|
File without changes
|