plancraft 0.3.18__py3-none-any.whl → 0.3.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plancraft/evaluator.py +12 -9
- {plancraft-0.3.18.dist-info → plancraft-0.3.20.dist-info}/METADATA +1 -1
- {plancraft-0.3.18.dist-info → plancraft-0.3.20.dist-info}/RECORD +5 -5
- {plancraft-0.3.18.dist-info → plancraft-0.3.20.dist-info}/WHEEL +0 -0
- {plancraft-0.3.18.dist-info → plancraft-0.3.20.dist-info}/licenses/LICENSE +0 -0
plancraft/evaluator.py
CHANGED
@@ -4,7 +4,6 @@ from typing import Optional
|
|
4
4
|
from copy import deepcopy
|
5
5
|
|
6
6
|
import imageio
|
7
|
-
from loguru import logger
|
8
7
|
from tqdm import tqdm
|
9
8
|
|
10
9
|
import wandb
|
@@ -293,6 +292,8 @@ class Evaluator:
|
|
293
292
|
# Get observations for all active environments
|
294
293
|
observations = []
|
295
294
|
active_indices = []
|
295
|
+
active_histories = []
|
296
|
+
|
296
297
|
for i, (env, action, active) in enumerate(
|
297
298
|
zip(environments, actions, active_mask)
|
298
299
|
):
|
@@ -311,17 +312,14 @@ class Evaluator:
|
|
311
312
|
"example_id": examples[i].id,
|
312
313
|
"images": histories[i].images,
|
313
314
|
}
|
314
|
-
logger.info("STOP")
|
315
315
|
continue
|
316
316
|
|
317
|
-
|
317
|
+
# Get observation
|
318
318
|
if isinstance(action, str):
|
319
|
-
# Handle message action
|
320
319
|
obs = env.step()
|
321
320
|
obs["target"] = examples[i].target
|
322
321
|
obs["message"] = action
|
323
322
|
else:
|
324
|
-
# Handle environment action
|
325
323
|
obs = env.step(action)
|
326
324
|
obs["target"] = examples[i].target
|
327
325
|
obs["message"] = self.convert_observation_to_message(
|
@@ -342,7 +340,12 @@ class Evaluator:
|
|
342
340
|
}
|
343
341
|
continue
|
344
342
|
|
343
|
+
# Add to batch lists
|
344
|
+
active_indices.append(i)
|
345
345
|
observations.append(obs)
|
346
|
+
active_histories.append(histories[i])
|
347
|
+
|
348
|
+
# Update history
|
346
349
|
histories[i].add_observation_to_history(obs)
|
347
350
|
histories[i].add_message_to_history(content=obs["message"], role="user")
|
348
351
|
steps_taken[i] += 1
|
@@ -351,20 +354,20 @@ class Evaluator:
|
|
351
354
|
break
|
352
355
|
|
353
356
|
# Batch predict actions for active environments
|
354
|
-
active_histories = [histories[i] for i in active_indices]
|
355
357
|
raw_actions = model.batch_step(
|
356
358
|
observations, dialogue_histories=active_histories
|
357
359
|
)
|
358
360
|
|
359
361
|
# Process actions for each active environment
|
360
|
-
for idx, raw_action in
|
361
|
-
|
362
|
+
for batch_idx, (idx, raw_action) in enumerate(
|
363
|
+
zip(active_indices, raw_actions)
|
364
|
+
):
|
362
365
|
histories[idx].add_message_to_history(
|
363
366
|
content=raw_action, role="assistant"
|
364
367
|
)
|
365
368
|
actions[idx] = self.parse_raw_model_response(
|
366
369
|
raw_action,
|
367
|
-
observation=observations[
|
370
|
+
observation=observations[batch_idx],
|
368
371
|
history=histories[idx],
|
369
372
|
)
|
370
373
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
plancraft/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
plancraft/config.py,sha256=ShsFRlJ7plsl3ToD9fiO_4LDQuXdbjNV6Xp6o3Yk2Yg,4315
|
3
|
-
plancraft/evaluator.py,sha256=
|
3
|
+
plancraft/evaluator.py,sha256=R_RZN9AL_ae0rIvj7HLhYolTpCVMuhPTJfIrmyoLaX4,16326
|
4
4
|
plancraft/generate_dataset.py,sha256=DlrU-PmvWqSNJD1g1-8Lpb8n3N-Ogw3rje1nrRzjGKs,2382
|
5
5
|
plancraft/utils.py,sha256=67UUDMSv8TqX_I0fL5-yG_vkHvTZlnhSLkktWAg5p34,5712
|
6
6
|
plancraft/data/test.json,sha256=4jWfYMAVuZCFmGB4iZJAjlh9_8jXECdaGp8xn7_tAM4,1317131
|
@@ -1920,7 +1920,7 @@ plancraft/models/generators.py,sha256=F76_iPiqxUjDIrQwF58tzM0bLM91OkZJ0sBqBuki5w
|
|
1920
1920
|
plancraft/models/oracle.py,sha256=f-0KWlBuHy6wcxmDsxM3MQ_QwfBstzfbA26mlk1MgLA,1657
|
1921
1921
|
plancraft/models/utils.py,sha256=E-sZohvolWgGbpHQKgAgkgIfUJoVnT5pMt6JP8xLHKg,4034
|
1922
1922
|
plancraft/train/dataset.py,sha256=oFqEd4LG9oEQ-71teh0Wf7-jJbtybT2ZibfM2bBdBkM,5474
|
1923
|
-
plancraft-0.3.
|
1924
|
-
plancraft-0.3.
|
1925
|
-
plancraft-0.3.
|
1926
|
-
plancraft-0.3.
|
1923
|
+
plancraft-0.3.20.dist-info/METADATA,sha256=7Gs3Txfw2YCBVDAstW7K69fho4vktcYp6sw6g2nQOYE,11148
|
1924
|
+
plancraft-0.3.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
1925
|
+
plancraft-0.3.20.dist-info/licenses/LICENSE,sha256=YGR8ehDB4t-T-lOQKMfKNR-2zsOU7E3E5NA8t25HKE0,1070
|
1926
|
+
plancraft-0.3.20.dist-info/RECORD,,
|
File without changes
|
File without changes
|