plancraft 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plancraft-0.1.2.dist-info/METADATA +74 -0
 - plancraft-0.1.2.dist-info/RECORD +5 -0
 - {plancraft-0.1.1.dist-info → plancraft-0.1.2.dist-info}/WHEEL +1 -1
 - plancraft-0.1.2.dist-info/top_level.txt +1 -0
 - environments/__init__.py +0 -0
 - environments/actions.py +0 -218
 - environments/env_real.py +0 -316
 - environments/env_symbolic.py +0 -212
 - environments/items.py +0 -10
 - environments/planner.py +0 -109
 - environments/recipes.py +0 -542
 - environments/sampler.py +0 -224
 - models/__init__.py +0 -21
 - models/act.py +0 -184
 - models/base.py +0 -152
 - models/bbox_model.py +0 -492
 - models/dummy.py +0 -54
 - models/few_shot_images/__init__.py +0 -16
 - models/generators.py +0 -480
 - models/oam.py +0 -283
 - models/oracle.py +0 -265
 - models/prompts.py +0 -158
 - models/react.py +0 -93
 - models/utils.py +0 -289
 - plancraft-0.1.1.dist-info/METADATA +0 -74
 - plancraft-0.1.1.dist-info/RECORD +0 -26
 - plancraft-0.1.1.dist-info/top_level.txt +0 -3
 - train/dataset.py +0 -187
 - {plancraft-0.1.1.dist-info → plancraft-0.1.2.dist-info}/LICENSE +0 -0
 
    
        environments/sampler.py
    DELETED
    
    | 
         @@ -1,224 +0,0 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            import math
         
     | 
| 
       2 
     | 
    
         
            -
            import random
         
     | 
| 
       3 
     | 
    
         
            -
            from collections import Counter
         
     | 
| 
       4 
     | 
    
         
            -
             
     | 
| 
       5 
     | 
    
         
            -
            import numpy as np
         
     | 
| 
       6 
     | 
    
         
            -
            from plancraft.environments.items import all_data, ALL_ITEMS
         
     | 
| 
       7 
     | 
    
         
            -
            from plancraft.environments.recipes import RECIPES
         
     | 
| 
       8 
     | 
    
         
            -
            from plancraft.environments.planner import optimal_planner, get_ancestors
         
     | 
| 
       9 
     | 
    
         
            -
             
     | 
| 
       10 
     | 
    
         
            -
             
     | 
| 
       11 
     | 
    
         
            -
            MAX_STACK_SIZE = {}
         
     | 
| 
       12 
     | 
    
         
            -
            for data_item in all_data["items"]:
         
     | 
| 
       13 
     | 
    
         
            -
                if data_item["stackable"]:
         
     | 
| 
       14 
     | 
    
         
            -
                    MAX_STACK_SIZE[data_item["type"]] = data_item["stackSize"]
         
     | 
| 
       15 
     | 
    
         
            -
                else:
         
     | 
| 
       16 
     | 
    
         
            -
                    MAX_STACK_SIZE[data_item["type"]] = 1
         
     | 
| 
       17 
     | 
    
         
            -
             
     | 
| 
       18 
     | 
    
         
            -
             
     | 
| 
       19 
     | 
    
         
            -
            def sample_distractors(
         
     | 
| 
       20 
     | 
    
         
            -
                exclude_set: set = None, num_distractors: int = 16
         
     | 
| 
       21 
     | 
    
         
            -
            ) -> dict[str, int]:
         
     | 
| 
       22 
     | 
    
         
            -
                distractors = {}
         
     | 
| 
       23 
     | 
    
         
            -
                while len(distractors) < num_distractors:
         
     | 
| 
       24 
     | 
    
         
            -
                    item = random.choice(ALL_ITEMS)
         
     | 
| 
       25 
     | 
    
         
            -
                    if exclude_set is not None and item in exclude_set:
         
     | 
| 
       26 
     | 
    
         
            -
                        continue
         
     | 
| 
       27 
     | 
    
         
            -
                    count = random.randint(1, MAX_STACK_SIZE[item])
         
     | 
| 
       28 
     | 
    
         
            -
                    distractors[item] = count
         
     | 
| 
       29 
     | 
    
         
            -
                return distractors
         
     | 
| 
       30 
     | 
    
         
            -
             
     | 
| 
       31 
     | 
    
         
            -
             
     | 
| 
       32 
     | 
    
         
            -
            def assign_to_slots(inventory: dict[str, int]) -> list[dict]:
         
     | 
| 
       33 
     | 
    
         
            -
                # slots available outside of crafting interface
         
     | 
| 
       34 
     | 
    
         
            -
                available_slots = list(range(10, 46))
         
     | 
| 
       35 
     | 
    
         
            -
                random.shuffle(available_slots)
         
     | 
| 
       36 
     | 
    
         
            -
                inventory_list = []
         
     | 
| 
       37 
     | 
    
         
            -
             
     | 
| 
       38 
     | 
    
         
            -
                for item, total_count in inventory.items():
         
     | 
| 
       39 
     | 
    
         
            -
                    while total_count > 0:
         
     | 
| 
       40 
     | 
    
         
            -
                        if len(available_slots) == 0:
         
     | 
| 
       41 
     | 
    
         
            -
                            print("Not enough slots available")
         
     | 
| 
       42 
     | 
    
         
            -
                            break
         
     | 
| 
       43 
     | 
    
         
            -
                        slot = available_slots.pop()
         
     | 
| 
       44 
     | 
    
         
            -
                        count_in_slot = min(total_count, MAX_STACK_SIZE[item])
         
     | 
| 
       45 
     | 
    
         
            -
                        inventory_list.append({"slot": slot, "item": item, "count": count_in_slot})
         
     | 
| 
       46 
     | 
    
         
            -
                        total_count -= count_in_slot
         
     | 
| 
       47 
     | 
    
         
            -
             
     | 
| 
       48 
     | 
    
         
            -
                return inventory_list
         
     | 
| 
       49 
     | 
    
         
            -
             
     | 
| 
       50 
     | 
    
         
            -
             
     | 
| 
       51 
     | 
    
         
            -
            def sample_recipes(
         
     | 
| 
       52 
     | 
    
         
            -
                target: str,
         
     | 
| 
       53 
     | 
    
         
            -
                overall_exclude_set: set,
         
     | 
| 
       54 
     | 
    
         
            -
                target_count: int = 1,
         
     | 
| 
       55 
     | 
    
         
            -
                current_depth=0,
         
     | 
| 
       56 
     | 
    
         
            -
                max_depth=20,
         
     | 
| 
       57 
     | 
    
         
            -
            ) -> tuple[set, set]:
         
     | 
| 
       58 
     | 
    
         
            -
                # stop if the depth is too high
         
     | 
| 
       59 
     | 
    
         
            -
                if current_depth > max_depth:
         
     | 
| 
       60 
     | 
    
         
            -
                    return {}, overall_exclude_set
         
     | 
| 
       61 
     | 
    
         
            -
             
     | 
| 
       62 
     | 
    
         
            -
                # get all the recipes that can craft the target
         
     | 
| 
       63 
     | 
    
         
            -
                overall_exclude_set.update([target])
         
     | 
| 
       64 
     | 
    
         
            -
                local_exclude_set = set()
         
     | 
| 
       65 
     | 
    
         
            -
                random_recipes = []
         
     | 
| 
       66 
     | 
    
         
            -
                for r in RECIPES[target]:
         
     | 
| 
       67 
     | 
    
         
            -
                    recipe_inputs, exclude_set = r.sample_inputs()
         
     | 
| 
       68 
     | 
    
         
            -
                    # if inputs are already in the exclude set, skip this recipe (ensures no cycle)
         
     | 
| 
       69 
     | 
    
         
            -
                    if exclude_set.intersection(overall_exclude_set):
         
     | 
| 
       70 
     | 
    
         
            -
                        return {}, overall_exclude_set
         
     | 
| 
       71 
     | 
    
         
            -
                    local_exclude_set.update(exclude_set)
         
     | 
| 
       72 
     | 
    
         
            -
                    random_recipes.append((r, recipe_inputs))
         
     | 
| 
       73 
     | 
    
         
            -
             
     | 
| 
       74 
     | 
    
         
            -
                overall_exclude_set |= local_exclude_set
         
     | 
| 
       75 
     | 
    
         
            -
             
     | 
| 
       76 
     | 
    
         
            -
                # no recipes found
         
     | 
| 
       77 
     | 
    
         
            -
                if len(random_recipes) == 0:
         
     | 
| 
       78 
     | 
    
         
            -
                    return {}, overall_exclude_set
         
     | 
| 
       79 
     | 
    
         
            -
             
     | 
| 
       80 
     | 
    
         
            -
                # sample a random recipe
         
     | 
| 
       81 
     | 
    
         
            -
                random_recipe = random.choice(random_recipes)
         
     | 
| 
       82 
     | 
    
         
            -
                recipe, start_inputs = random_recipe
         
     | 
| 
       83 
     | 
    
         
            -
             
     | 
| 
       84 
     | 
    
         
            -
                # recipe will not produce enough
         
     | 
| 
       85 
     | 
    
         
            -
                if recipe.result.count < target_count:
         
     | 
| 
       86 
     | 
    
         
            -
                    # must do recipe X times
         
     | 
| 
       87 
     | 
    
         
            -
                    recipe_multiplier = math.ceil(target_count / recipe.result.count)
         
     | 
| 
       88 
     | 
    
         
            -
                    start_inputs = {k: v * recipe_multiplier for k, v in start_inputs.items()}
         
     | 
| 
       89 
     | 
    
         
            -
             
     | 
| 
       90 
     | 
    
         
            -
                for input_item in list(start_inputs.keys()):
         
     | 
| 
       91 
     | 
    
         
            -
                    # randomize depth first search to end early
         
     | 
| 
       92 
     | 
    
         
            -
                    if random.choice([True, False]):
         
     | 
| 
       93 
     | 
    
         
            -
                        continue
         
     | 
| 
       94 
     | 
    
         
            -
             
     | 
| 
       95 
     | 
    
         
            -
                    children_recipe_inputs, updated_exclude_set = sample_recipes(
         
     | 
| 
       96 
     | 
    
         
            -
                        target=input_item,
         
     | 
| 
       97 
     | 
    
         
            -
                        overall_exclude_set=overall_exclude_set,
         
     | 
| 
       98 
     | 
    
         
            -
                        target_count=start_inputs[input_item],
         
     | 
| 
       99 
     | 
    
         
            -
                        current_depth=current_depth + 1,
         
     | 
| 
       100 
     | 
    
         
            -
                    )
         
     | 
| 
       101 
     | 
    
         
            -
                    if len(children_recipe_inputs) == 0:
         
     | 
| 
       102 
     | 
    
         
            -
                        continue
         
     | 
| 
       103 
     | 
    
         
            -
             
     | 
| 
       104 
     | 
    
         
            -
                    overall_exclude_set.update(updated_exclude_set)
         
     | 
| 
       105 
     | 
    
         
            -
             
     | 
| 
       106 
     | 
    
         
            -
                    # remove recipe input item since we are crafting it
         
     | 
| 
       107 
     | 
    
         
            -
                    start_inputs[input_item] = 0
         
     | 
| 
       108 
     | 
    
         
            -
             
     | 
| 
       109 
     | 
    
         
            -
                    # add the children recipe inputs
         
     | 
| 
       110 
     | 
    
         
            -
                    for item, count in children_recipe_inputs.items():
         
     | 
| 
       111 
     | 
    
         
            -
                        start_inputs[item] = start_inputs.get(item, 0) + count
         
     | 
| 
       112 
     | 
    
         
            -
             
     | 
| 
       113 
     | 
    
         
            -
                overall_exclude_set = overall_exclude_set - {None}
         
     | 
| 
       114 
     | 
    
         
            -
                start_inputs = {k: v for k, v in start_inputs.items() if v > 0}
         
     | 
| 
       115 
     | 
    
         
            -
             
     | 
| 
       116 
     | 
    
         
            -
                return start_inputs, overall_exclude_set
         
     | 
| 
       117 
     | 
    
         
            -
             
     | 
| 
       118 
     | 
    
         
            -
             
     | 
| 
       119 
     | 
    
         
            -
            def remove_ancestor_items(target: str, inventory: dict[str, int]) -> dict[str, int]:
         
     | 
| 
       120 
     | 
    
         
            -
                ancestors = set(get_ancestors(target))
         
     | 
| 
       121 
     | 
    
         
            -
                possible_items = set(inventory.keys())
         
     | 
| 
       122 
     | 
    
         
            -
                items_to_remove = list(ancestors.intersection(possible_items))
         
     | 
| 
       123 
     | 
    
         
            -
                num_items = random.randint(1, len(items_to_remove))
         
     | 
| 
       124 
     | 
    
         
            -
                for item in random.sample(items_to_remove, num_items):
         
     | 
| 
       125 
     | 
    
         
            -
                    count_to_remove = random.randint(1, inventory[item])
         
     | 
| 
       126 
     | 
    
         
            -
                    inventory[item] -= count_to_remove
         
     | 
| 
       127 
     | 
    
         
            -
                    if inventory[item] == 0:
         
     | 
| 
       128 
     | 
    
         
            -
                        del inventory[item]
         
     | 
| 
       129 
     | 
    
         
            -
                return inventory
         
     | 
| 
       130 
     | 
    
         
            -
             
     | 
| 
       131 
     | 
    
         
            -
             
     | 
| 
       132 
     | 
    
         
            -
            def construct_example(
         
     | 
| 
       133 
     | 
    
         
            -
                target: str,
         
     | 
| 
       134 
     | 
    
         
            -
                num_distractors: 16,
         
     | 
| 
       135 
     | 
    
         
            -
                impossible=False,
         
     | 
| 
       136 
     | 
    
         
            -
            ) -> list[dict]:
         
     | 
| 
       137 
     | 
    
         
            -
                """
         
     | 
| 
       138 
     | 
    
         
            -
                For a given target object, number of distractors, and impossible flag
         
     | 
| 
       139 
     | 
    
         
            -
                Return a dictionary with the start inventory for the crafting task
         
     | 
| 
       140 
     | 
    
         
            -
             
     | 
| 
       141 
     | 
    
         
            -
                The crafting task should be to craft the target, the inventory should contain
         
     | 
| 
       142 
     | 
    
         
            -
                the resources required for the recipe to be crafted.
         
     | 
| 
       143 
     | 
    
         
            -
             
     | 
| 
       144 
     | 
    
         
            -
                The number of distractors are how many random items should be added to the inventory.
         
     | 
| 
       145 
     | 
    
         
            -
             
     | 
| 
       146 
     | 
    
         
            -
                If impossible is True, the target item should not be craftable with the given inventory.
         
     | 
| 
       147 
     | 
    
         
            -
                """
         
     | 
| 
       148 
     | 
    
         
            -
             
     | 
| 
       149 
     | 
    
         
            -
                # sample the recipe
         
     | 
| 
       150 
     | 
    
         
            -
                inventory, overall_exclude_set = sample_recipes(target, set())
         
     | 
| 
       151 
     | 
    
         
            -
                if impossible:
         
     | 
| 
       152 
     | 
    
         
            -
                    # if impossible then remove one or more items from the inventory
         
     | 
| 
       153 
     | 
    
         
            -
                    inventory = remove_ancestor_items(
         
     | 
| 
       154 
     | 
    
         
            -
                        target,
         
     | 
| 
       155 
     | 
    
         
            -
                        inventory,
         
     | 
| 
       156 
     | 
    
         
            -
                    )
         
     | 
| 
       157 
     | 
    
         
            -
             
     | 
| 
       158 
     | 
    
         
            -
                # add distractors to the inventory
         
     | 
| 
       159 
     | 
    
         
            -
                distractors = sample_distractors(overall_exclude_set, num_distractors)
         
     | 
| 
       160 
     | 
    
         
            -
                inventory.update(distractors)
         
     | 
| 
       161 
     | 
    
         
            -
             
     | 
| 
       162 
     | 
    
         
            -
                optimal_path = optimal_planner(target, inventory)
         
     | 
| 
       163 
     | 
    
         
            -
                # @TODO this is a hack to ensure that we don't have impossible examples
         
     | 
| 
       164 
     | 
    
         
            -
                while optimal_path is not None and impossible:
         
     | 
| 
       165 
     | 
    
         
            -
                    inventory = remove_ancestor_items(target, inventory)
         
     | 
| 
       166 
     | 
    
         
            -
                    optimal_path = optimal_planner(target, inventory)
         
     | 
| 
       167 
     | 
    
         
            -
             
     | 
| 
       168 
     | 
    
         
            -
                # assign to slots
         
     | 
| 
       169 
     | 
    
         
            -
                inventory_list = assign_to_slots(inventory)
         
     | 
| 
       170 
     | 
    
         
            -
                example = {
         
     | 
| 
       171 
     | 
    
         
            -
                    "inventory": inventory,
         
     | 
| 
       172 
     | 
    
         
            -
                    "slotted_inventory": inventory_list,
         
     | 
| 
       173 
     | 
    
         
            -
                    "target": target,
         
     | 
| 
       174 
     | 
    
         
            -
                    "num_distractors": num_distractors,
         
     | 
| 
       175 
     | 
    
         
            -
                    "impossible": impossible,
         
     | 
| 
       176 
     | 
    
         
            -
                }
         
     | 
| 
       177 
     | 
    
         
            -
                # either impossible and no path or not impossible and path exists
         
     | 
| 
       178 
     | 
    
         
            -
                assert (impossible and optimal_path is None) or (
         
     | 
| 
       179 
     | 
    
         
            -
                    not impossible and optimal_path is not None
         
     | 
| 
       180 
     | 
    
         
            -
                )
         
     | 
| 
       181 
     | 
    
         
            -
             
     | 
| 
       182 
     | 
    
         
            -
                if not impossible:
         
     | 
| 
       183 
     | 
    
         
            -
                    example["optimal_path_length"] = len(optimal_path)
         
     | 
| 
       184 
     | 
    
         
            -
                    example["optimal_path"] = [r.result.item for (r, i) in optimal_path]
         
     | 
| 
       185 
     | 
    
         
            -
                    example["inventory_trace"] = [i for (r, i) in optimal_path]
         
     | 
| 
       186 
     | 
    
         
            -
                    items_used, unique_items_used = calculate_stats_from_inventory_trace(
         
     | 
| 
       187 
     | 
    
         
            -
                        [example["inventory"]] + example["inventory_trace"]
         
     | 
| 
       188 
     | 
    
         
            -
                    )
         
     | 
| 
       189 
     | 
    
         
            -
                    example["items_used"] = items_used
         
     | 
| 
       190 
     | 
    
         
            -
                    example["unique_items_used"] = unique_items_used
         
     | 
| 
       191 
     | 
    
         
            -
             
     | 
| 
       192 
     | 
    
         
            -
                return example
         
     | 
| 
       193 
     | 
    
         
            -
             
     | 
| 
       194 
     | 
    
         
            -
             
     | 
| 
       195 
     | 
    
         
            -
            def calculate_stats_from_inventory_trace(
         
     | 
| 
       196 
     | 
    
         
            -
                inventory_trace: list[dict],
         
     | 
| 
       197 
     | 
    
         
            -
            ) -> tuple[int, int]:
         
     | 
| 
       198 
     | 
    
         
            -
                total_items_used = 0
         
     | 
| 
       199 
     | 
    
         
            -
                total_unique_items_used = 0
         
     | 
| 
       200 
     | 
    
         
            -
             
     | 
| 
       201 
     | 
    
         
            -
                for a, b in zip(inventory_trace[:-1], inventory_trace[1:]):
         
     | 
| 
       202 
     | 
    
         
            -
                    diff = Counter(a) - Counter(b)
         
     | 
| 
       203 
     | 
    
         
            -
                    total_items_used += sum(diff.values())
         
     | 
| 
       204 
     | 
    
         
            -
                    total_unique_items_used += len(diff)
         
     | 
| 
       205 
     | 
    
         
            -
             
     | 
| 
       206 
     | 
    
         
            -
                return total_items_used, total_unique_items_used
         
     | 
| 
       207 
     | 
    
         
            -
             
     | 
| 
       208 
     | 
    
         
            -
             
     | 
| 
       209 
     | 
    
         
            -
            def generate_dataset(seed=2024, distractors=[4, 8, 16], num_examples=10):
         
     | 
| 
       210 
     | 
    
         
            -
                random.seed(seed)
         
     | 
| 
       211 
     | 
    
         
            -
                np.random.seed(seed)
         
     | 
| 
       212 
     | 
    
         
            -
             
     | 
| 
       213 
     | 
    
         
            -
                dataset = []
         
     | 
| 
       214 
     | 
    
         
            -
                for recipe_target in list(RECIPES.keys()):
         
     | 
| 
       215 
     | 
    
         
            -
                    if len(RECIPES[recipe_target]) == 0:
         
     | 
| 
       216 
     | 
    
         
            -
                        continue
         
     | 
| 
       217 
     | 
    
         
            -
                    for num_distractors in distractors:
         
     | 
| 
       218 
     | 
    
         
            -
                        for _ in range(num_examples):
         
     | 
| 
       219 
     | 
    
         
            -
                            example = construct_example(
         
     | 
| 
       220 
     | 
    
         
            -
                                target=recipe_target, num_distractors=num_distractors
         
     | 
| 
       221 
     | 
    
         
            -
                            )
         
     | 
| 
       222 
     | 
    
         
            -
                            dataset.append(example)
         
     | 
| 
       223 
     | 
    
         
            -
             
     | 
| 
       224 
     | 
    
         
            -
                return dataset
         
     | 
    
        models/__init__.py
    DELETED
    
    | 
         @@ -1,21 +0,0 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            from plancraft.models.base import ABCModel
         
     | 
| 
       2 
     | 
    
         
            -
             
     | 
| 
       3 
     | 
    
         
            -
            from plancraft.config import EvalConfig
         
     | 
| 
       4 
     | 
    
         
            -
            from plancraft.models.dummy import DummyModel
         
     | 
| 
       5 
     | 
    
         
            -
            from plancraft.models.react import ReactModel
         
     | 
| 
       6 
     | 
    
         
            -
            from plancraft.models.oracle import OracleModel
         
     | 
| 
       7 
     | 
    
         
            -
            from plancraft.models.act import ActModel
         
     | 
| 
       8 
     | 
    
         
            -
             
     | 
| 
       9 
     | 
    
         
            -
             
     | 
| 
       10 
     | 
    
         
            -
            def get_model(cfg: EvalConfig) -> ABCModel:
         
     | 
| 
       11 
     | 
    
         
            -
                """
         
     | 
| 
       12 
     | 
    
         
            -
                Factory get model (default: ReactModel)
         
     | 
| 
       13 
     | 
    
         
            -
                """
         
     | 
| 
       14 
     | 
    
         
            -
                if cfg.plancraft.mode == "dummy":
         
     | 
| 
       15 
     | 
    
         
            -
                    return DummyModel(cfg)
         
     | 
| 
       16 
     | 
    
         
            -
                elif cfg.plancraft.mode == "oracle":
         
     | 
| 
       17 
     | 
    
         
            -
                    return OracleModel(cfg)
         
     | 
| 
       18 
     | 
    
         
            -
                elif cfg.plancraft.mode == "act":
         
     | 
| 
       19 
     | 
    
         
            -
                    return ActModel(cfg)
         
     | 
| 
       20 
     | 
    
         
            -
                else:
         
     | 
| 
       21 
     | 
    
         
            -
                    return ReactModel(cfg)
         
     | 
    
        models/act.py
    DELETED
    
    | 
         @@ -1,184 +0,0 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            import copy
         
     | 
| 
       2 
     | 
    
         
            -
            import torch
         
     | 
| 
       3 
     | 
    
         
            -
            from dotenv import load_dotenv
         
     | 
| 
       4 
     | 
    
         
            -
             
     | 
| 
       5 
     | 
    
         
            -
            from plancraft.config import EvalConfig
         
     | 
| 
       6 
     | 
    
         
            -
            from plancraft.environments.actions import (
         
     | 
| 
       7 
     | 
    
         
            -
                NoOp,
         
     | 
| 
       8 
     | 
    
         
            -
                StopAction,
         
     | 
| 
       9 
     | 
    
         
            -
                SymbolicAction,
         
     | 
| 
       10 
     | 
    
         
            -
            )
         
     | 
| 
       11 
     | 
    
         
            -
            from plancraft.models.base import ABCModel, History
         
     | 
| 
       12 
     | 
    
         
            -
            from plancraft.models.bbox_model import IntegratedBoundingBoxModel
         
     | 
| 
       13 
     | 
    
         
            -
            from plancraft.models.few_shot_images import load_prompt_images
         
     | 
| 
       14 
     | 
    
         
            -
            from plancraft.models.generators import (
         
     | 
| 
       15 
     | 
    
         
            -
                OAMGenerator,
         
     | 
| 
       16 
     | 
    
         
            -
                OpenAIGenerator,
         
     | 
| 
       17 
     | 
    
         
            -
                TransformersGenerator,
         
     | 
| 
       18 
     | 
    
         
            -
            )
         
     | 
| 
       19 
     | 
    
         
            -
            from plancraft.models.prompts import get_prompt_example, get_system_prompt
         
     | 
| 
       20 
     | 
    
         
            -
            from plancraft.models.utils import (
         
     | 
| 
       21 
     | 
    
         
            -
                convert_observation_to_message,
         
     | 
| 
       22 
     | 
    
         
            -
                parse_content_response,
         
     | 
| 
       23 
     | 
    
         
            -
            )
         
     | 
| 
       24 
     | 
    
         
            -
             
     | 
| 
       25 
     | 
    
         
            -
             
     | 
| 
       26 
     | 
    
         
            -
            # Import-time side effect: load variables from a local .env file (python-dotenv)
            # into the process environment before any model/generator is constructed.
            load_dotenv()
         
     | 
| 
       27 
     | 
    
         
            -
             
     | 
| 
       28 
     | 
    
         
            -
             
     | 
| 
       29 
     | 
    
         
            -
            class ActModel(ABCModel):
                """
                Act-only agent: every LLM reply is parsed directly as an action,
                with no separate reasoning ("thinking") turn.
                """

                def __init__(self, cfg: EvalConfig):
                    """
                    Build the agent from ``cfg.plancraft``.

                    Sets up the optional MaskRCNN bounding-box model, selects the LLM
                    backend from the model name, and prepares the system prompt and the
                    (optionally few-shot) dialogue history.

                    Raises:
                        AssertionError: if the symbolic action space is disabled, or if
                            MaskRCNN is requested for a non-multimodal environment.
                    """
                    settings = cfg.plancraft
                    assert (
                        settings.environment.symbolic_action_space
                    ), "Real action space unsupported"
                    self.cfg = cfg
                    self.env_is_multimodal = not settings.environment.symbolic
                    self.use_maskrcnn = settings.use_maskrcnn
                    self.use_multimodal_content_format = settings.use_multimodal_content_format
                    self.use_text_inventory = settings.use_text_inventory
                    self.use_images = settings.use_images

                    # MaskRCNN is a standalone detector, separate from the LLM; step()
                    # passes it to convert_observation_to_message.
                    self.bbox_model = None
                    if self.use_maskrcnn:
                        assert self.env_is_multimodal, "MaskRCNN only supported in multimodal mode"
                        self.bbox_model = IntegratedBoundingBoxModel.from_pretrained(
                            "gautierdag/plancraft-maskrcnn"
                        )
                        self.bbox_model.eval()
                        if torch.cuda.is_available():
                            self.bbox_model.cuda()

                    self.few_shot = settings.few_shot
                    self.use_system_prompt = settings.system_prompt
                    self.max_invalid_actions = 3

                    # The LLM backend is chosen by substring match on the model name.
                    model_name = settings.model
                    if "gpt-4o" in model_name:
                        # gpt-4o always uses the list-of-parts message content format.
                        self.use_multimodal_content_format = True
                        self.llm = OpenAIGenerator(
                            use_images=self.use_images, model_name=model_name
                        )
                    elif "oam" in model_name:
                        self.llm = OAMGenerator(model_name=model_name)
                    else:
                        # any other name is loaded as a transformers-based model
                        self.llm = TransformersGenerator(
                            model_name=model_name,
                            tokenizer_name=settings.tokenizer,
                            quantize=settings.quantize,
                            use_hot_cache=settings.hot_cache,
                            adapter_name=settings.adapter,
                        )

                    self.prompt_images = []

                    self.valid_actions = settings.valid_actions
                    self.system_prompt_text = get_system_prompt(self.valid_actions)

                    # Examples are rendered even when few-shot is off; they are then
                    # dropped when the history is built below.
                    examples = self._few_shot_examples()
                    if self.env_is_multimodal and self.use_images:
                        self.prompt_images = load_prompt_images()

                    prompt_text = copy.deepcopy(self.system_prompt_text)
                    if self.use_multimodal_content_format:
                        prompt_content = [{"text": prompt_text, "type": "text"}]
                    else:
                        prompt_content = prompt_text
                    self.system_prompt = {"role": "system", "content": prompt_content}
                    if not self.use_system_prompt:
                        self.system_prompt = None

                    self.history = History(
                        initial_dialogue=examples if self.few_shot else [],
                        use_multimodal_content_format=self.use_multimodal_content_format,
                    )

                    self.max_messages_window = settings.max_message_window
                    self.kv_cache = None

                def _few_shot_examples(self):
                    """Render the few-shot example dialogue for the current prompt settings."""
                    return get_prompt_example(
                        self.valid_actions,
                        use_text_inventory=self.use_text_inventory,
                        use_multimodal_content_format=self.use_multimodal_content_format,
                        use_images=self.use_images,
                    )

                def reset_history(
                    self,
                    objective: str,
                ):
                    """Start a new episode for *objective*: reset the history and the LLM."""
                    examples = self._few_shot_examples() if self.few_shot else []
                    self.history.reset(objective=objective, initial_dialogue=examples)
                    self.llm.reset()

                def step(self, observation: dict) -> SymbolicAction | StopAction:
                    """
                    Choose the next action for *observation*.

                    The observation is recorded in the history and rendered as a user
                    message. The LLM is then queried up to ``max_invalid_actions`` times.
                    Only the first line of each reply is kept and parsed. A parsed
                    action is recorded and returned. A parse failure (a string returned
                    by ``parse_content_response``) is appended as a user message, and
                    the model is prompted again. Returns ``NoOp()`` when no attempt
                    yields a valid action.
                    """
                    self.history.add_observation_to_history(observation)

                    # render the observation as the next user message
                    observation_message = convert_observation_to_message(
                        observation,
                        objective=self.history.objective,
                        bbox_model=self.bbox_model,
                        oam_model="oam" in self.llm.model_name,
                        use_text_inventory=self.use_text_inventory,
                        use_multimodal_content_format=self.use_multimodal_content_format,
                        use_images=self.use_images,
                    )
                    self.history.add_message_to_history(content=observation_message, role="user")

                    for _ in range(self.max_invalid_actions):
                        # build the (windowed) prompt from the current history
                        message_window, image_window = self.llm.prepare_messages(
                            history=self.history,
                            max_messages_window=self.max_messages_window,
                            system_prompt=self.system_prompt,
                            prompt_images=self.prompt_images,
                        )
                        generations, tokens = self.llm.generate_unconstrained(
                            batch_messages=[message_window],
                            images=[image_window],
                        )
                        self.history.tokens_used += tokens

                        # only the first line of the single generation is used
                        first_line = generations[0].partition("\n")[0].strip()
                        self.history.add_message_to_history(
                            content=first_line, role="assistant"
                        )

                        parsed = parse_content_response(
                            first_line, valid_actions=self.valid_actions
                        )
                        if isinstance(parsed, str):
                            # parse failure: feed the message back (as a user turn) and retry
                            self.history.add_message_to_history(
                                content=parsed,
                            )
                            continue

                        self.history.add_action_to_history(parsed)
                        return parsed

                    # no valid action after max_invalid_actions attempts: do nothing
                    return NoOp()
         
     | 
    
        models/base.py
    DELETED
    
    | 
         @@ -1,152 +0,0 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            import abc
         
     | 
| 
       2 
     | 
    
         
            -
             
     | 
| 
       3 
     | 
    
         
            -
            from copy import copy
         
     | 
| 
       4 
     | 
    
         
            -
            from collections import Counter
         
     | 
| 
       5 
     | 
    
         
            -
             
     | 
| 
       6 
     | 
    
         
            -
            from plancraft.environments.actions import (
         
     | 
| 
       7 
     | 
    
         
            -
                SymbolicMoveAction,
         
     | 
| 
       8 
     | 
    
         
            -
                RealActionInteraction,
         
     | 
| 
       9 
     | 
    
         
            -
                SymbolicSmeltAction,
         
     | 
| 
       10 
     | 
    
         
            -
            )
         
     | 
| 
       11 
     | 
    
         
            -
             
     | 
| 
       12 
     | 
    
         
            -
             
     | 
| 
       13 
     | 
    
         
            -
            class History:
         
     | 
| 
       14 
     | 
    
         
            -
                def __init__(
         
     | 
| 
       15 
     | 
    
         
            -
                    self,
         
     | 
| 
       16 
     | 
    
         
            -
                    objective: str = "",
         
     | 
| 
       17 
     | 
    
         
            -
                    initial_dialogue: list[dict] = [],
         
     | 
| 
       18 
     | 
    
         
            -
                    use_multimodal_content_format=False,
         
     | 
| 
       19 
     | 
    
         
            -
                ):
         
     | 
| 
       20 
     | 
    
         
            -
                    self.dialogue_history = initial_dialogue
         
     | 
| 
       21 
     | 
    
         
            -
                    self.initial_dialogue_length = len(initial_dialogue)
         
     | 
| 
       22 
     | 
    
         
            -
                    self.action_history = []
         
     | 
| 
       23 
     | 
    
         
            -
                    self.inventory_history = []
         
     | 
| 
       24 
     | 
    
         
            -
                    self.inventory_counters = []
         
     | 
| 
       25 
     | 
    
         
            -
                    self.images = []
         
     | 
| 
       26 
     | 
    
         
            -
                    self.objective = objective
         
     | 
| 
       27 
     | 
    
         
            -
                    self.tokens_used = 0
         
     | 
| 
       28 
     | 
    
         
            -
                    self.use_multimodal_content_format = use_multimodal_content_format
         
     | 
| 
       29 
     | 
    
         
            -
             
     | 
| 
       30 
     | 
    
         
            -
                def add_message_to_history(self, content: str | dict, role="user"):
         
     | 
| 
       31 
     | 
    
         
            -
                    if role == "assistant":
         
     | 
| 
       32 
     | 
    
         
            -
                        print(content)
         
     | 
| 
       33 
     | 
    
         
            -
             
     | 
| 
       34 
     | 
    
         
            -
                    if isinstance(content, dict):
         
     | 
| 
       35 
     | 
    
         
            -
                        assert "content" in content, "content key not found in message"
         
     | 
| 
       36 
     | 
    
         
            -
                        content["role"] = role
         
     | 
| 
       37 
     | 
    
         
            -
                        self.dialogue_history.append(content)
         
     | 
| 
       38 
     | 
    
         
            -
                    else:
         
     | 
| 
       39 
     | 
    
         
            -
                        # fix for listed content type
         
     | 
| 
       40 
     | 
    
         
            -
                        if self.use_multimodal_content_format:
         
     | 
| 
       41 
     | 
    
         
            -
                            return self.add_message_to_history(
         
     | 
| 
       42 
     | 
    
         
            -
                                content={
         
     | 
| 
       43 
     | 
    
         
            -
                                    "content": [{"type": "text", "text": content}],
         
     | 
| 
       44 
     | 
    
         
            -
                                    "role": role,
         
     | 
| 
       45 
     | 
    
         
            -
                                },
         
     | 
| 
       46 
     | 
    
         
            -
                                role=role,
         
     | 
| 
       47 
     | 
    
         
            -
                            )
         
     | 
| 
       48 
     | 
    
         
            -
                        else:
         
     | 
| 
       49 
     | 
    
         
            -
                            self.dialogue_history.append({"role": role, "content": content})
         
     | 
| 
       50 
     | 
    
         
            -
             
     | 
| 
       51 
     | 
    
         
            -
                def add_action_to_history(
         
     | 
| 
       52 
     | 
    
         
            -
                    self, action: SymbolicSmeltAction | RealActionInteraction | SymbolicMoveAction
         
     | 
| 
       53 
     | 
    
         
            -
                ):
         
     | 
| 
       54 
     | 
    
         
            -
                    if action is None:
         
     | 
| 
       55 
     | 
    
         
            -
                        return
         
     | 
| 
       56 
     | 
    
         
            -
                    self.action_history.append(action.model_dump())
         
     | 
| 
       57 
     | 
    
         
            -
             
     | 
| 
       58 
     | 
    
         
            -
                def add_inventory_to_history(self, inventory: list[dict[str, int]]):
         
     | 
| 
       59 
     | 
    
         
            -
                    self.inventory_history.append(inventory)
         
     | 
| 
       60 
     | 
    
         
            -
             
     | 
| 
       61 
     | 
    
         
            -
                    # count inventory
         
     | 
| 
       62 
     | 
    
         
            -
                    counter = Counter()
         
     | 
| 
       63 
     | 
    
         
            -
                    for item in inventory:
         
     | 
| 
       64 
     | 
    
         
            -
                        # ignore slot 0
         
     | 
| 
       65 
     | 
    
         
            -
                        if "slot" in item and item["slot"] == 0:
         
     | 
| 
       66 
     | 
    
         
            -
                            continue
         
     | 
| 
       67 
     | 
    
         
            -
                        if "index" in item and item["index"] == 0:
         
     | 
| 
       68 
     | 
    
         
            -
                            continue
         
     | 
| 
       69 
     | 
    
         
            -
                        counter[item["type"]] += item["quantity"]
         
     | 
| 
       70 
     | 
    
         
            -
             
     | 
| 
       71 
     | 
    
         
            -
                    self.inventory_counters.append(counter)
         
     | 
| 
       72 
     | 
    
         
            -
             
     | 
| 
       73 
     | 
    
         
            -
                def add_image_to_history(self, image):
         
     | 
| 
       74 
     | 
    
         
            -
                    self.images.append(image)
         
     | 
| 
       75 
     | 
    
         
            -
             
     | 
| 
       76 
     | 
    
         
            -
                def add_observation_to_history(self, observation: dict):
         
     | 
| 
       77 
     | 
    
         
            -
                    if observation is None:
         
     | 
| 
       78 
     | 
    
         
            -
                        return
         
     | 
| 
       79 
     | 
    
         
            -
                    if "inventory" in observation:
         
     | 
| 
       80 
     | 
    
         
            -
                        clean_inv = []
         
     | 
| 
       81 
     | 
    
         
            -
                        # remove empty slots
         
     | 
| 
       82 
     | 
    
         
            -
                        for item in observation["inventory"]:
         
     | 
| 
       83 
     | 
    
         
            -
                            if item["quantity"] > 0:
         
     | 
| 
       84 
     | 
    
         
            -
                                clean_inv.append(item)
         
     | 
| 
       85 
     | 
    
         
            -
                        self.add_inventory_to_history(clean_inv)
         
     | 
| 
       86 
     | 
    
         
            -
                    if "pov" in observation:
         
     | 
| 
       87 
     | 
    
         
            -
                        self.add_image_to_history(observation["pov"])
         
     | 
| 
       88 
     | 
    
         
            -
             
     | 
| 
       89 
     | 
    
         
            -
                def __str__(self):
         
     | 
| 
       90 
     | 
    
         
            -
                    return str(self.dialogue_history)
         
     | 
| 
       91 
     | 
    
         
            -
             
     | 
| 
       92 
     | 
    
         
            -
                def reset(self, objective: str = "", initial_dialogue: list[dict] = []):
         
     | 
| 
       93 
     | 
    
         
            -
                    self.dialogue_history = initial_dialogue
         
     | 
| 
       94 
     | 
    
         
            -
                    self.action_history = []
         
     | 
| 
       95 
     | 
    
         
            -
                    self.inventory_history = []
         
     | 
| 
       96 
     | 
    
         
            -
                    self.inventory_counters = []
         
     | 
| 
       97 
     | 
    
         
            -
                    self.images = []
         
     | 
| 
       98 
     | 
    
         
            -
                    self.objective = objective
         
     | 
| 
       99 
     | 
    
         
            -
             
     | 
| 
       100 
     | 
    
         
            -
                def set_objective(self, objective: str):
         
     | 
| 
       101 
     | 
    
         
            -
                    self.objective = objective
         
     | 
| 
       102 
     | 
    
         
            -
             
     | 
| 
       103 
     | 
    
         
            -
                def trace(self):
         
     | 
| 
       104 
     | 
    
         
            -
                    return {
         
     | 
| 
       105 
     | 
    
         
            -
                        "dialogue_history": copy(
         
     | 
| 
       106 
     | 
    
         
            -
                            self.dialogue_history[self.initial_dialogue_length :]
         
     | 
| 
       107 
     | 
    
         
            -
                        ),
         
     | 
| 
       108 
     | 
    
         
            -
                        "action_history": copy(self.action_history),
         
     | 
| 
       109 
     | 
    
         
            -
                        "inventory_history": copy(self.inventory_history),
         
     | 
| 
       110 
     | 
    
         
            -
                        "objective": copy(self.objective),
         
     | 
| 
       111 
     | 
    
         
            -
                        "tokens_used": copy(self.tokens_used),
         
     | 
| 
       112 
     | 
    
         
            -
                    }
         
     | 
| 
       113 
     | 
    
         
            -
             
     | 
| 
       114 
     | 
    
         
            -
                @property
         
     | 
| 
       115 
     | 
    
         
            -
                def num_steps(self):
         
     | 
| 
       116 
     | 
    
         
            -
                    return len(self.action_history)
         
     | 
| 
       117 
     | 
    
         
            -
             
     | 
| 
       118 
     | 
    
         
            -
                def check_stuck(self, max_steps_no_change: int = 10) -> bool:
         
     | 
| 
       119 
     | 
    
         
            -
                    """
         
     | 
| 
       120 
     | 
    
         
            -
                    If inventory content does not change for max_steps_no_change steps
         
     | 
| 
       121 
     | 
    
         
            -
                    the agent is considered stuck.
         
     | 
| 
       122 
     | 
    
         
            -
             
     | 
| 
       123 
     | 
    
         
            -
                    With N=10, the oracle solver can still solve 100% of the examples
         
     | 
| 
       124 
     | 
    
         
            -
                    """
         
     | 
| 
       125 
     | 
    
         
            -
                    if len(self.inventory_counters) <= max_steps_no_change:
         
     | 
| 
       126 
     | 
    
         
            -
                        return False
         
     | 
| 
       127 
     | 
    
         
            -
             
     | 
| 
       128 
     | 
    
         
            -
                    return all(
         
     | 
| 
       129 
     | 
    
         
            -
                        c == self.inventory_counters[-max_steps_no_change - 1]
         
     | 
| 
       130 
     | 
    
         
            -
                        for c in self.inventory_counters[-max_steps_no_change - 1 :]
         
     | 
| 
       131 
     | 
    
         
            -
                    )
         
     | 
| 
       132 
     | 
    
         
            -
             
     | 
| 
       133 
     | 
    
         
            -
             
     | 
| 
       134 
     | 
    
         
            -
            class ABCModel(abc.ABC):
         
     | 
| 
       135 
     | 
    
         
            -
                """
         
     | 
| 
       136 
     | 
    
         
            -
                Model class must implement the following methods to work with evaluator
         
     | 
| 
       137 
     | 
    
         
            -
                """
         
     | 
| 
       138 
     | 
    
         
            -
             
     | 
| 
       139 
     | 
    
         
            -
                @abc.abstractmethod
         
     | 
| 
       140 
     | 
    
         
            -
                def step(
         
     | 
| 
       141 
     | 
    
         
            -
                    self, observation: list[dict]
         
     | 
| 
       142 
     | 
    
         
            -
                ) -> list[SymbolicMoveAction | RealActionInteraction | SymbolicSmeltAction]:
         
     | 
| 
       143 
     | 
    
         
            -
                    """
         
     | 
| 
       144 
     | 
    
         
            -
                    Model should output a valid action based on the 3 types available
         
     | 
| 
       145 
     | 
    
         
            -
             
     | 
| 
       146 
     | 
    
         
            -
                    Note this is a batch operation, so the model should return a list of actions
         
     | 
| 
       147 
     | 
    
         
            -
                    for each observation in the batch
         
     | 
| 
       148 
     | 
    
         
            -
                    """
         
     | 
| 
       149 
     | 
    
         
            -
                    raise NotImplementedError()
         
     | 
| 
       150 
     | 
    
         
            -
             
     | 
| 
       151 
     | 
    
         
            -
                def reset_history(self, objective: str = ""):
         
     | 
| 
       152 
     | 
    
         
            -
                    self.history.reset(objective=objective)
         
     |