job-shop-lib 1.0.0b5__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. job_shop_lib/__init__.py +1 -1
  2. job_shop_lib/_job_shop_instance.py +2 -2
  3. job_shop_lib/_operation.py +9 -3
  4. job_shop_lib/_scheduled_operation.py +3 -0
  5. job_shop_lib/benchmarking/__init__.py +1 -0
  6. job_shop_lib/dispatching/__init__.py +12 -10
  7. job_shop_lib/dispatching/_dispatcher.py +6 -13
  8. job_shop_lib/dispatching/_factories.py +3 -3
  9. job_shop_lib/dispatching/_optimal_operations_observer.py +0 -2
  10. job_shop_lib/dispatching/_ready_operation_filters.py +4 -4
  11. job_shop_lib/dispatching/feature_observers/_composite_feature_observer.py +11 -6
  12. job_shop_lib/dispatching/feature_observers/_factory.py +8 -3
  13. job_shop_lib/dispatching/feature_observers/_feature_observer.py +1 -1
  14. job_shop_lib/dispatching/feature_observers/_is_completed_observer.py +35 -67
  15. job_shop_lib/dispatching/rules/__init__.py +11 -8
  16. job_shop_lib/dispatching/rules/_dispatching_rule_factory.py +1 -1
  17. job_shop_lib/dispatching/rules/_machine_chooser_factory.py +3 -2
  18. job_shop_lib/generation/__init__.py +12 -1
  19. job_shop_lib/graphs/__init__.py +42 -8
  20. job_shop_lib/graphs/_build_resource_task_graphs.py +1 -1
  21. job_shop_lib/graphs/_job_shop_graph.py +38 -19
  22. job_shop_lib/graphs/graph_updaters/__init__.py +5 -1
  23. job_shop_lib/graphs/graph_updaters/_disjunctive_graph_updater.py +108 -0
  24. job_shop_lib/graphs/graph_updaters/_residual_graph_updater.py +3 -1
  25. job_shop_lib/graphs/graph_updaters/_utils.py +2 -2
  26. job_shop_lib/py.typed +0 -0
  27. job_shop_lib/reinforcement_learning/__init__.py +13 -7
  28. job_shop_lib/reinforcement_learning/_multi_job_shop_graph_env.py +1 -1
  29. job_shop_lib/reinforcement_learning/_resource_task_graph_observation.py +102 -24
  30. job_shop_lib/reinforcement_learning/_single_job_shop_graph_env.py +11 -2
  31. job_shop_lib/reinforcement_learning/_types_and_constants.py +11 -10
  32. job_shop_lib/reinforcement_learning/_utils.py +29 -0
  33. job_shop_lib/visualization/gantt/__init__.py +7 -3
  34. job_shop_lib/visualization/gantt/_gantt_chart_video_and_gif_creation.py +5 -2
  35. job_shop_lib/visualization/graphs/__init__.py +1 -0
  36. job_shop_lib/visualization/graphs/_plot_disjunctive_graph.py +53 -19
  37. {job_shop_lib-1.0.0b5.dist-info → job_shop_lib-1.0.2.dist-info}/METADATA +19 -18
  38. {job_shop_lib-1.0.0b5.dist-info → job_shop_lib-1.0.2.dist-info}/RECORD +40 -38
  39. {job_shop_lib-1.0.0b5.dist-info → job_shop_lib-1.0.2.dist-info}/LICENSE +0 -0
  40. {job_shop_lib-1.0.0b5.dist-info → job_shop_lib-1.0.2.dist-info}/WHEEL +0 -0
@@ -132,15 +132,13 @@ class JobShopGraph:
132
132
 
133
133
  This method assigns a unique identifier to the node, adds it to the
134
134
  graph, and updates the nodes list and the nodes_by_type dictionary. If
135
- the node is of type `OPERATION`, it also updates `nodes_by_job` and
136
- `nodes_by_machine` based on the operation's job_id and machine_ids.
135
+ the node is of type :class:`NodeType.OPERATION`, it also updates
136
+ ``nodes_by_job`` and ``nodes_by_machine`` based on the operation's
137
+ job id and machine ids.
137
138
 
138
139
  Args:
139
- node_for_adding (Node): The node to be added to the graph.
140
-
141
- Raises:
142
- ValueError: If the node type is unsupported or if required
143
- attributes for the node type are missing.
140
+ node_for_adding:
141
+ The node to be added to the graph.
144
142
 
145
143
  Note:
146
144
  This method directly modifies the graph attribute as well as
@@ -171,17 +169,25 @@ class JobShopGraph:
171
169
  ) -> None:
172
170
  """Adds an edge to the graph.
173
171
 
172
+ It automatically determines the edge type based on the source and
173
+ destination nodes unless explicitly provided in the ``attr`` argument
174
+ via the ``type`` key. The edge type is a tuple of strings:
175
+ ``(source_node_type, "to", destination_node_type)``.
176
+
174
177
  Args:
175
- u_of_edge: The source node of the edge. If it is a `Node`, its
176
- `node_id` is used as the source. Otherwise, it is assumed to be
177
- the node_id of the source.
178
- v_of_edge: The destination node of the edge. If it is a `Node`, its
179
- `node_id` is used as the destination. Otherwise, it is assumed
180
- to be the node_id of the destination.
181
- **attr: Additional attributes to be added to the edge.
178
+ u_of_edge:
179
+ The source node of the edge. If it is a :class:`Node`, its
180
+ ``node_id`` is used as the source. Otherwise, it is assumed to
181
+ be the ``node_id`` of the source.
182
+ v_of_edge:
183
+ The destination node of the edge. If it is a :class:`Node`,
184
+ its ``node_id`` is used as the destination. Otherwise, it
185
+ is assumed to be the ``node_id`` of the destination.
186
+ **attr:
187
+ Additional attributes to be added to the edge.
182
188
 
183
189
  Raises:
184
- ValidationError: If `u_of_edge` or `v_of_edge` are not in the
190
+ ValidationError: If ``u_of_edge`` or ``v_of_edge`` are not in the
185
191
  graph.
186
192
  """
187
193
  if isinstance(u_of_edge, Node):
@@ -192,18 +198,30 @@ class JobShopGraph:
192
198
  raise ValidationError(
193
199
  "`u_of_edge` and `v_of_edge` must be in the graph."
194
200
  )
195
- self.graph.add_edge(u_of_edge, v_of_edge, **attr)
201
+ edge_type = attr.pop("type", None)
202
+ if edge_type is None:
203
+ u_node = self.nodes[u_of_edge]
204
+ v_node = self.nodes[v_of_edge]
205
+ edge_type = (
206
+ u_node.node_type.name.lower(),
207
+ "to",
208
+ v_node.node_type.name.lower(),
209
+ )
210
+ self.graph.add_edge(u_of_edge, v_of_edge, type=edge_type, **attr)
196
211
 
197
212
  def remove_node(self, node_id: int) -> None:
198
213
  """Removes a node from the graph and the isolated nodes that result
199
214
  from the removal.
200
215
 
201
216
  Args:
202
- node_id: The id of the node to remove.
217
+ node_id:
218
+ The id of the node to remove.
203
219
  """
204
220
  self.graph.remove_node(node_id)
205
221
  self.removed_nodes[node_id] = True
206
222
 
223
+ def remove_isolated_nodes(self) -> None:
224
+ """Removes isolated nodes from the graph."""
207
225
  isolated_nodes = list(nx.isolates(self.graph))
208
226
  for isolated_node in isolated_nodes:
209
227
  self.removed_nodes[isolated_node] = True
@@ -214,9 +232,10 @@ class JobShopGraph:
214
232
  """Returns whether the node is removed from the graph.
215
233
 
216
234
  Args:
217
- node: The node to check. If it is a `Node`, its `node_id` is used
235
+ node:
236
+ The node to check. If it is a ``Node``, its `node_id` is used
218
237
  as the node to check. Otherwise, it is assumed to be the
219
- `node_id` of the node to check.
238
+ ``node_id`` of the node to check.
220
239
  """
221
240
  if isinstance(node, Node):
222
241
  node = node.node_id
@@ -4,9 +4,11 @@ job shop scheduling problem.
4
4
  Currently, the following classes and utilities are available:
5
5
 
6
6
  .. autosummary::
7
+ :nosignatures:
7
8
 
8
9
  GraphUpdater
9
10
  ResidualGraphUpdater
11
+ DisjunctiveGraphUpdater
10
12
  remove_completed_operations
11
13
 
12
14
  """
@@ -14,10 +16,12 @@ Currently, the following classes and utilities are available:
14
16
  from ._graph_updater import GraphUpdater
15
17
  from ._utils import remove_completed_operations
16
18
  from ._residual_graph_updater import ResidualGraphUpdater
19
+ from ._disjunctive_graph_updater import DisjunctiveGraphUpdater
17
20
 
18
21
 
19
22
  __all__ = [
20
23
  "GraphUpdater",
21
- "remove_completed_operations",
22
24
  "ResidualGraphUpdater",
25
+ "DisjunctiveGraphUpdater",
26
+ "remove_completed_operations",
23
27
  ]
@@ -0,0 +1,108 @@
1
+ """Home of the `DisjunctiveGraphUpdater` class."""
2
+
3
+ from job_shop_lib import ScheduledOperation
4
+ from job_shop_lib.graphs.graph_updaters import (
5
+ ResidualGraphUpdater,
6
+ )
7
+ from job_shop_lib.exceptions import ValidationError
8
+
9
+
10
class DisjunctiveGraphUpdater(ResidualGraphUpdater):
    """Updates a disjunctive graph as operations are dispatched.

    Removal of completed operation nodes (and, optionally, of completed
    machine and job nodes) is inherited from :class:`ResidualGraphUpdater`.
    In addition, every time an operation is dispatched on a machine that
    already had an operation scheduled, this updater prunes disjunctive arcs:

    * the arc from the newly scheduled operation back to the previously
      scheduled operation on the same machine is removed, so only the arc
      ``previous -> scheduled`` remains; and
    * all disjunctive arcs (in both directions) between the previously
      scheduled operation and the operations of that machine that have not
      been scheduled yet are removed.

    Attributes:
        remove_completed_machine_nodes:
            If ``True``, removes completed machine nodes from the graph.
        remove_completed_job_nodes:
            If ``True``, removes completed job nodes from the graph.

    Args:
        dispatcher:
            The dispatcher instance to observe.
        job_shop_graph:
            The job shop graph to update. Operation nodes must have been the
            first nodes added to it, so that each operation's
            ``operation_id`` equals its node id.
        subscribe:
            If ``True``, automatically subscribes the observer to the
            dispatcher. Defaults to ``True``.
        remove_completed_machine_nodes:
            If ``True``, removes completed machine nodes from the graph.
            Defaults to ``True``.
        remove_completed_job_nodes:
            If ``True``, removes completed job nodes from the graph.
            Defaults to ``True``.
    """

    def update(self, scheduled_operation: ScheduledOperation) -> None:
        """Updates the disjunctive graph after an operation is dispatched.

        After the residual-graph update, the arc from ``scheduled_operation``
        back to the previous operation on its machine is dropped. The
        disjunctive arcs between that previous operation and the machine's
        still-unscheduled operations are also removed.

        Args:
            scheduled_operation:
                The scheduled operation that was dispatched.

        Raises:
            ValidationError:
                If the graph node whose id equals the operation id does not
                hold ``scheduled_operation.operation``. This happens when the
                operation nodes were not the first nodes added to the graph.
        """
        # Remove completed nodes first (inherited behaviour).
        super().update(scheduled_operation)
        machine_schedule = self.dispatcher.schedule.schedule[
            scheduled_operation.machine_id
        ]
        # With at most one operation on the machine there is no predecessor
        # whose arcs need pruning.
        if len(machine_schedule) <= 1:
            return

        # NOTE: assumes ``scheduled_operation`` is the last entry of its
        # machine schedule, so its predecessor is the second-to-last one.
        previous_scheduled_operation = machine_schedule[-2]

        # Operation ids are used directly as node ids; verify that
        # assumption before touching any edge.
        scheduled_operation_node = self.job_shop_graph.nodes[
            scheduled_operation.operation.operation_id
        ]
        if (
            scheduled_operation_node.operation
            is not scheduled_operation.operation
        ):
            raise ValidationError(
                "Scheduled operation node does not match scheduled "
                "operation. Make sure that the operation nodes have been the "
                "first to be added to the graph. This method assumes that the "
                "operation id and node id are the same."
            )
        scheduled_id = scheduled_operation_node.node_id
        assert scheduled_id == scheduled_operation.operation.operation_id
        previous_id = previous_scheduled_operation.operation.operation_id

        # A removed node has already lost all of its edges.
        if self.job_shop_graph.is_removed(
            previous_id
        ) or self.job_shop_graph.is_removed(scheduled_id):
            return

        # Keep only ``previous -> scheduled``; drop the reverse arc.
        self.job_shop_graph.graph.remove_edge(scheduled_id, previous_id)

        # Drop every disjunctive arc between the previous scheduled
        # operation and the operations of this machine not scheduled yet.
        operations_with_same_machine = (
            self.dispatcher.instance.operations_by_machine[
                scheduled_operation.machine_id
            ]
        )
        already_scheduled_operations = {
            scheduled_op.operation.operation_id
            for scheduled_op in machine_schedule
        }
        for operation in operations_with_same_machine:
            operation_id = operation.operation_id
            if operation_id in already_scheduled_operations:
                continue
            # Defensive: a node that was already removed has no edges left,
            # and ``remove_edge`` would raise for it.
            if self.job_shop_graph.is_removed(operation_id):
                continue
            self.job_shop_graph.graph.remove_edge(previous_id, operation_id)
            self.job_shop_graph.graph.remove_edge(operation_id, previous_id)
@@ -112,7 +112,9 @@ class ResidualGraphUpdater(GraphUpdater):
112
112
  """Updates the residual graph based on the completed operations."""
113
113
  remove_completed_operations(
114
114
  self.job_shop_graph,
115
- completed_operations=self.dispatcher.completed_operations(),
115
+ completed_operations=(
116
+ op.operation for op in self.dispatcher.completed_operations()
117
+ ),
116
118
  )
117
119
  graph_has_machine_nodes = bool(
118
120
  self.job_shop_graph.nodes_by_type[NodeType.MACHINE]
@@ -1,4 +1,4 @@
1
- """Contains grah updater functions to update """
1
+ """Contains utility functions for updating the job shop graph."""
2
2
 
3
3
  from collections.abc import Iterable
4
4
 
@@ -13,7 +13,7 @@ def remove_completed_operations(
13
13
  """Removes the operation node of the scheduled operation from the graph.
14
14
 
15
15
  Args:
16
- graph:
16
+ job_shop_graph:
17
17
  The job shop graph to update.
18
18
  dispatcher:
19
19
  The dispatcher instance.
job_shop_lib/py.typed ADDED
File without changes
@@ -1,20 +1,24 @@
1
1
  """Contains reinforcement learning components.
2
2
 
3
3
 
4
+
4
5
  .. autosummary::
6
+ :nosignatures:
5
7
 
6
8
  SingleJobShopGraphEnv
7
9
  MultiJobShopGraphEnv
8
10
  ObservationDict
9
11
  ObservationSpaceKey
12
+ ResourceTaskGraphObservation
13
+ ResourceTaskGraphObservationDict
10
14
  RewardObserver
11
15
  MakespanReward
12
16
  IdleTimeReward
13
17
  RenderConfig
14
18
  add_padding
15
19
  create_edge_type_dict
16
- ResourceTaskGraphObservation
17
- ResourceTaskGraphObservationDict
20
+ map_values
21
+ get_optimal_actions
18
22
 
19
23
  """
20
24
 
@@ -34,6 +38,7 @@ from job_shop_lib.reinforcement_learning._utils import (
34
38
  add_padding,
35
39
  create_edge_type_dict,
36
40
  map_values,
41
+ get_optimal_actions,
37
42
  )
38
43
 
39
44
  from job_shop_lib.reinforcement_learning._single_job_shop_graph_env import (
@@ -48,17 +53,18 @@ from ._resource_task_graph_observation import (
48
53
 
49
54
 
50
55
  __all__ = [
56
+ "SingleJobShopGraphEnv",
57
+ "MultiJobShopGraphEnv",
58
+ "ObservationDict",
51
59
  "ObservationSpaceKey",
60
+ "ResourceTaskGraphObservation",
61
+ "ResourceTaskGraphObservationDict",
52
62
  "RewardObserver",
53
63
  "MakespanReward",
54
64
  "IdleTimeReward",
55
- "SingleJobShopGraphEnv",
56
65
  "RenderConfig",
57
- "ObservationDict",
58
66
  "add_padding",
59
- "MultiJobShopGraphEnv",
60
67
  "create_edge_type_dict",
61
- "ResourceTaskGraphObservation",
62
68
  "map_values",
63
- "ResourceTaskGraphObservationDict",
69
+ "get_optimal_actions",
64
70
  ]
@@ -117,7 +117,7 @@ class MultiJobShopGraphEnv(gym.Env):
117
117
  graph_initializer:
118
118
  Function to create the initial graph representation.
119
119
  If ``None``, the default graph initializer is used:
120
- :func:`~job_shop_lib.graphs.build_agent_task_graph`.
120
+ :func:`~job_shop_lib.graphs.build_resource_task_graph`.
121
121
  graph_updater_config:
122
122
  Configuration for the graph updater. The graph updater is used
123
123
  to update the graph representation after each action. If
@@ -1,6 +1,6 @@
1
1
  """Contains wrappers for the environments."""
2
2
 
3
- from typing import TypeVar, TypedDict, Generic
3
+ from typing import TypeVar, TypedDict, Generic, Any
4
4
  from gymnasium import ObservationWrapper
5
5
  import numpy as np
6
6
  from numpy.typing import NDArray
@@ -20,11 +20,22 @@ EnvType = TypeVar( # pylint: disable=invalid-name
20
20
  "EnvType", bound=SingleJobShopGraphEnv | MultiJobShopGraphEnv
21
21
  )
22
22
 
23
# Maps each graph node type to the feature type whose matrix holds the
# features of that node type.
_NODE_TYPE_TO_FEATURE_TYPE = {
    NodeType.OPERATION: FeatureType.OPERATIONS,
    NodeType.MACHINE: FeatureType.MACHINES,
    NodeType.JOB: FeatureType.JOBS,
}
# Reverse lookup keyed by the feature type's string value (the keys of the
# node features dictionary). Derived from the table above so that the two
# mappings cannot drift apart.
_FEATURE_TYPE_STR_TO_NODE_TYPE = {
    feature_type.value: node_type
    for node_type, feature_type in _NODE_TYPE_TO_FEATURE_TYPE.items()
}
33
+
23
34
 
24
35
  class ResourceTaskGraphObservationDict(TypedDict):
25
36
  """Represents a dictionary for resource task graph observations."""
26
37
 
27
- edge_index_dict: dict[str, NDArray[np.int64]]
38
+ edge_index_dict: dict[tuple[str, str, str], NDArray[np.int32]]
28
39
  node_features_dict: dict[str, NDArray[np.float32]]
29
40
  original_ids_dict: dict[str, NDArray[np.int32]]
30
41
 
@@ -40,6 +51,12 @@ class ResourceTaskGraphObservation(ObservationWrapper, Generic[EnvType]):
40
51
  ``node_type_j`` are the node types of the source and target nodes,
41
52
  respectively.
42
53
 
54
+ Additionally, the node features are stored in a dictionary with keys
55
+ corresponding to the node type names under the ``node_features_dict`` key.
56
+
57
+ The node IDs are mapped to local IDs starting from 0. The
58
+ ``original_ids_dict`` contains the original node IDs before removing nodes.
59
+
43
60
  Attributes:
44
61
  global_to_local_id: A dictionary mapping global node IDs to local node
45
62
  IDs for each node type.
@@ -55,6 +72,7 @@ class ResourceTaskGraphObservation(ObservationWrapper, Generic[EnvType]):
55
72
  self.env = env # Unnecessary, but makes mypy happy
56
73
  self.global_to_local_id = self._compute_id_mappings()
57
74
  self.type_ranges = self._compute_node_type_ranges()
75
+ self._start_from_zero_mapping: dict[str, dict[int, int]] = {}
58
76
 
59
77
  def step(self, action: tuple[int, int]):
60
78
  """Takes a step in the environment.
@@ -80,7 +98,9 @@ class ResourceTaskGraphObservation(ObservationWrapper, Generic[EnvType]):
80
98
  machine_id, job_id).
81
99
  """
82
100
  observation, reward, done, truncated, info = self.env.step(action)
83
- return self.observation(observation), reward, done, truncated, info
101
+ new_observation = self.observation(observation)
102
+ new_info = self._info(info)
103
+ return new_observation, reward, done, truncated, new_info
84
104
 
85
105
  def reset(self, *, seed: int | None = None, options: dict | None = None):
86
106
  """Resets the environment.
@@ -104,7 +124,34 @@ class ResourceTaskGraphObservation(ObservationWrapper, Generic[EnvType]):
104
124
  (operation_id, machine_id, job_id).
105
125
  """
106
126
  observation, info = self.env.reset()
107
- return self.observation(observation), info
127
+ new_observation = self.observation(observation)
128
+ new_info = self._info(info)
129
+ return new_observation, new_info
130
+
131
def _info(self, info: dict[str, Any]) -> dict[str, Any]:
    """Remaps the ids in ``info["available_operations_with_ids"]``.

    Each ``(operation_id, machine_id, job_id)`` tuple is translated with
    ``self._start_from_zero_mapping``, so that the ids match the zero-based
    indices used in the processed observation's edge index. Node types
    without an entry in that mapping keep their original ids.

    The input dictionary is left untouched; a shallow copy carrying the
    remapped list is returned instead.

    Args:
        info:
            The info dictionary returned by the wrapped environment. It must
            contain the ``"available_operations_with_ids"`` key.

    Returns:
        A shallow copy of ``info`` whose ``"available_operations_with_ids"``
        entry holds the remapped tuples.

    Raises:
        KeyError:
            If ``info`` lacks ``"available_operations_with_ids"`` or an id is
            missing from the corresponding mapping.
    """
    # Loop-invariant: resolve each per-type mapping once.
    operation_map = self._start_from_zero_mapping.get("operation")
    machine_map = self._start_from_zero_mapping.get("machine")
    job_map = self._start_from_zero_mapping.get("job")

    new_available_operations_ids = []
    for operation_id, machine_id, job_id in info[
        "available_operations_with_ids"
    ]:
        if operation_map is not None:
            operation_id = operation_map[operation_id]
        if machine_map is not None:
            machine_id = machine_map[machine_id]
        if job_map is not None:
            job_id = job_map[job_id]
        new_available_operations_ids.append(
            (operation_id, machine_id, job_id)
        )

    # Copy instead of mutating the dictionary owned by the wrapped env.
    new_info = dict(info)
    new_info["available_operations_with_ids"] = new_available_operations_ids
    return new_info
108
155
 
109
156
  def _compute_id_mappings(self) -> dict[int, int]:
110
157
  """Computes mappings from global node IDs to type-local IDs.
@@ -145,21 +192,50 @@ class ResourceTaskGraphObservation(ObservationWrapper, Generic[EnvType]):
145
192
 
146
193
  return type_ranges
147
194
 
148
- def observation(self, observation: ObservationDict):
195
+ def observation(
196
+ self, observation: ObservationDict
197
+ ) -> ResourceTaskGraphObservationDict:
198
+ """Processes the observation data into the resource task graph format.
199
+
200
+ Args:
201
+ observation: The observation dictionary. It must NOT have padding.
202
+
203
+ Returns:
204
+ A dictionary containing the following keys:
205
+
206
+ - "edge_index_dict": A dictionary mapping edge types to edge index
207
+ arrays.
208
+ - "node_features_dict": A dictionary mapping node type names to
209
+ node feature arrays.
210
+ - "original_ids_dict": A dictionary mapping node type names to the
211
+ original node IDs before removing nodes.
212
+ """
149
213
  edge_index_dict = create_edge_type_dict(
150
214
  observation["edge_index"],
151
215
  type_ranges=self.type_ranges,
152
216
  relationship="to",
153
217
  )
218
+ node_features_dict = self._create_node_features_dict(observation)
219
+ node_features_dict, original_ids_dict = self._remove_nodes(
220
+ node_features_dict, observation["removed_nodes"]
221
+ )
222
+
154
223
  # mapping from global node ID to local node ID
155
224
  for key, edge_index in edge_index_dict.items():
156
225
  edge_index_dict[key] = map_values(
157
226
  edge_index, self.global_to_local_id
158
227
  )
159
- node_features_dict = self._create_node_features_dict(observation)
160
- node_features_dict, original_ids_dict = self._remove_nodes(
161
- node_features_dict, observation["removed_nodes"]
228
+ # mapping so that ids start from 0 in edge index
229
+ self._start_from_zero_mapping = self._get_start_from_zero_mappings(
230
+ original_ids_dict
162
231
  )
232
+ for (type_1, to, type_2), edge_index in edge_index_dict.items():
233
+ edge_index_dict[(type_1, to, type_2)][0] = map_values(
234
+ edge_index[0], self._start_from_zero_mapping[type_1]
235
+ )
236
+ edge_index_dict[(type_1, to, type_2)][1] = map_values(
237
+ edge_index[1], self._start_from_zero_mapping[type_2]
238
+ )
163
239
 
164
240
  return {
165
241
  "edge_index_dict": edge_index_dict,
@@ -167,6 +243,15 @@ class ResourceTaskGraphObservation(ObservationWrapper, Generic[EnvType]):
167
243
  "original_ids_dict": original_ids_dict,
168
244
  }
169
245
 
246
@staticmethod
def _get_start_from_zero_mappings(
    original_indices_dict: dict[str, NDArray[np.int32]]
) -> dict[str, dict[int, int]]:
    """Builds, per node type, a map from original node id to its position.

    Args:
        original_indices_dict:
            Maps each node type name to the array of original node ids that
            remain after removing nodes.

    Returns:
        For every node type, a dictionary sending each original id to its
        zero-based position in the corresponding array.
    """
    return {
        node_type: {
            original_id: local_id
            for local_id, original_id in enumerate(ids)
        }
        for node_type, ids in original_indices_dict.items()
    }
254
+
170
255
  def _create_node_features_dict(
171
256
  self, observation: ObservationDict
172
257
  ) -> dict[str, NDArray]:
@@ -178,14 +263,10 @@ class ResourceTaskGraphObservation(ObservationWrapper, Generic[EnvType]):
178
263
  Returns:
179
264
  Dictionary mapping node type names to node features.
180
265
  """
181
- node_type_to_feature_type = {
182
- NodeType.OPERATION: FeatureType.OPERATIONS,
183
- NodeType.MACHINE: FeatureType.MACHINES,
184
- NodeType.JOB: FeatureType.JOBS,
185
- }
266
+
186
267
  node_features_dict = {}
187
- for node_type, feature_type in node_type_to_feature_type.items():
188
- if node_type in self.unwrapped.job_shop_graph.nodes_by_type:
268
+ for node_type, feature_type in _NODE_TYPE_TO_FEATURE_TYPE.items():
269
+ if self.unwrapped.job_shop_graph.nodes_by_type[node_type]:
189
270
  node_features_dict[feature_type.value] = observation[
190
271
  feature_type.value
191
272
  ]
@@ -211,9 +292,9 @@ class ResourceTaskGraphObservation(ObservationWrapper, Generic[EnvType]):
211
292
 
212
293
  def _remove_nodes(
213
294
  self,
214
- node_features_dict: dict[str, NDArray[np.float32]],
295
+ node_features_dict: dict[str, NDArray[T]],
215
296
  removed_nodes: NDArray[np.bool_],
216
- ) -> tuple[dict[str, NDArray[np.float32]], dict[str, NDArray[np.int32]]]:
297
+ ) -> tuple[dict[str, NDArray[T]], dict[str, NDArray[np.int32]]]:
217
298
  """Removes nodes from the node features dictionary.
218
299
 
219
300
  Args:
@@ -223,15 +304,12 @@ class ResourceTaskGraphObservation(ObservationWrapper, Generic[EnvType]):
223
304
  The node features dictionary with the nodes removed and a
224
305
  dictionary containing the original node ids.
225
306
  """
226
- removed_nodes_dict: dict[str, NDArray[np.float32]] = {}
307
+ removed_nodes_dict: dict[str, NDArray[T]] = {}
227
308
  original_ids_dict: dict[str, NDArray[np.int32]] = {}
228
- feature_type_to_node_type = {
229
- FeatureType.OPERATIONS.value: NodeType.OPERATION,
230
- FeatureType.MACHINES.value: NodeType.MACHINE,
231
- FeatureType.JOBS.value: NodeType.JOB,
232
- }
233
309
  for feature_type, features in node_features_dict.items():
234
- node_type = feature_type_to_node_type[feature_type].name.lower()
310
+ node_type = _FEATURE_TYPE_STR_TO_NODE_TYPE[
311
+ feature_type
312
+ ].name.lower()
235
313
  if node_type not in self.type_ranges:
236
314
  continue
237
315
  start, end = self.type_ranges[node_type]
@@ -2,7 +2,7 @@
2
2
 
3
3
  from copy import deepcopy
4
4
  from collections.abc import Callable, Sequence
5
- from typing import Any, Dict, Tuple, List, Optional, Type
5
+ from typing import Any, Dict, Tuple, List, Optional, Type, Union
6
6
 
7
7
  import matplotlib.pyplot as plt
8
8
  import gymnasium as gym
@@ -24,6 +24,8 @@ from job_shop_lib.dispatching import (
24
24
  from job_shop_lib.dispatching.feature_observers import (
25
25
  FeatureObserverConfig,
26
26
  CompositeFeatureObserver,
27
+ FeatureObserver,
28
+ FeatureObserverType,
27
29
  )
28
30
  from job_shop_lib.visualization.gantt import GanttChartCreator
29
31
  from job_shop_lib.reinforcement_learning import (
@@ -137,7 +139,14 @@ class SingleJobShopGraphEnv(gym.Env):
137
139
  def __init__(
138
140
  self,
139
141
  job_shop_graph: JobShopGraph,
140
- feature_observer_configs: Sequence[FeatureObserverConfig],
142
+ feature_observer_configs: Sequence[
143
+ Union[
144
+ str,
145
+ FeatureObserverType,
146
+ Type[FeatureObserver],
147
+ FeatureObserverConfig,
148
+ ],
149
+ ],
141
150
  reward_function_config: DispatcherObserverConfig[
142
151
  Type[RewardObserver]
143
152
  ] = DispatcherObserverConfig(class_type=MakespanReward),
@@ -5,6 +5,7 @@ from enum import Enum
5
5
  from typing import TypedDict
6
6
 
7
7
  import numpy as np
8
+ from numpy.typing import NDArray
8
9
 
9
10
  from job_shop_lib.dispatching.feature_observers import FeatureType
10
11
  from job_shop_lib.visualization.gantt import (
@@ -35,27 +36,27 @@ class ObservationSpaceKey(str, Enum):
35
36
  class _ObservationDictRequired(TypedDict):
36
37
  """Required fields for the observation dictionary."""
37
38
 
38
- removed_nodes: np.ndarray
39
- edge_index: np.ndarray
39
+ removed_nodes: NDArray[np.bool_]
40
+ edge_index: NDArray[np.int32]
40
41
 
41
42
 
42
43
  class _ObservationDictOptional(TypedDict, total=False):
43
44
  """Optional fields for the observation dictionary."""
44
45
 
45
- operations: np.ndarray
46
- jobs: np.ndarray
47
- machines: np.ndarray
46
+ operations: NDArray[np.float32]
47
+ jobs: NDArray[np.float32]
48
+ machines: NDArray[np.float32]
48
49
 
49
50
 
50
51
  class ObservationDict(_ObservationDictRequired, _ObservationDictOptional):
51
52
  """A dictionary containing the observation of the environment.
52
53
 
53
54
  Required fields:
54
- removed_nodes (np.ndarray): Binary vector indicating removed nodes.
55
- edge_index (np.ndarray): Edge list in COO format.
55
+ removed_nodes: Binary vector indicating removed nodes.
56
+ edge_index: Edge list in COO format.
56
57
 
57
58
  Optional fields:
58
- operations (np.ndarray): Matrix of operation features.
59
- jobs (np.ndarray): Matrix of job features.
60
- machines (np.ndarray): Matrix of machine features.
59
+ operations: Matrix of operation features.
60
+ jobs: Matrix of job features.
61
+ machines: Matrix of machine features.
61
62
  """
@@ -6,6 +6,7 @@ import numpy as np
6
6
  from numpy.typing import NDArray
7
7
 
8
8
  from job_shop_lib.exceptions import ValidationError
9
+ from job_shop_lib.dispatching import OptimalOperationsObserver
9
10
 
10
11
  T = TypeVar("T", bound=np.number)
11
12
 
@@ -164,6 +165,34 @@ def map_values(array: NDArray[T], mapping: dict[int, int]) -> NDArray[T]:
164
165
  ) from e
165
166
 
166
167
 
168
def get_optimal_actions(
    optimal_ops_observer: "OptimalOperationsObserver",
    available_operations_with_ids: list[tuple[int, int, int]],
) -> dict[tuple[int, int, int], int]:
    """Flags which of the available actions are optimal.

    An action is optimal when it matches one of the operations exposed by
    the ``optimal_available`` attribute of a
    :class:`~job_shop_lib.dispatching.OptimalOperationsObserver` instance.

    Args:
        optimal_ops_observer:
            The observer that provides the optimal available operations.
        available_operations_with_ids:
            Available actions as ``(operation_id, machine_id, job_id)``
            tuples.

    Returns:
        A dictionary mapping each ``(operation_id, machine_id, job_id)``
        tuple in ``available_operations_with_ids`` to a binary indicator:
        ``1`` if it is optimal, ``0`` otherwise.
    """
    # A set gives O(1) membership tests instead of scanning a list once per
    # available action.
    optimal_ops_ids = {
        (op.operation_id, op.machine_id, op.job_id)
        for op in optimal_ops_observer.optimal_available
    }
    optimal_actions: dict[tuple[int, int, int], int] = {}
    for operation_id, machine_id, job_id in available_operations_with_ids:
        action = (operation_id, machine_id, job_id)
        optimal_actions[action] = int(action in optimal_ops_ids)
    return optimal_actions
+
195
+
167
196
  if __name__ == "__main__":
168
197
  import doctest
169
198