job-shop-lib 1.0.0b4__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. job_shop_lib/__init__.py +1 -1
  2. job_shop_lib/_operation.py +9 -3
  3. job_shop_lib/_scheduled_operation.py +3 -0
  4. job_shop_lib/dispatching/_dispatcher.py +6 -13
  5. job_shop_lib/dispatching/_factories.py +3 -3
  6. job_shop_lib/dispatching/_optimal_operations_observer.py +0 -2
  7. job_shop_lib/dispatching/_ready_operation_filters.py +4 -4
  8. job_shop_lib/dispatching/feature_observers/_composite_feature_observer.py +10 -5
  9. job_shop_lib/dispatching/feature_observers/_factory.py +8 -3
  10. job_shop_lib/dispatching/feature_observers/_feature_observer.py +1 -1
  11. job_shop_lib/dispatching/feature_observers/_is_completed_observer.py +35 -67
  12. job_shop_lib/dispatching/rules/_dispatching_rule_factory.py +1 -1
  13. job_shop_lib/dispatching/rules/_machine_chooser_factory.py +3 -2
  14. job_shop_lib/graphs/__init__.py +2 -0
  15. job_shop_lib/graphs/_build_resource_task_graphs.py +1 -1
  16. job_shop_lib/graphs/_job_shop_graph.py +38 -19
  17. job_shop_lib/graphs/graph_updaters/__init__.py +3 -0
  18. job_shop_lib/graphs/graph_updaters/_disjunctive_graph_updater.py +108 -0
  19. job_shop_lib/graphs/graph_updaters/_residual_graph_updater.py +3 -1
  20. job_shop_lib/graphs/graph_updaters/_utils.py +2 -2
  21. job_shop_lib/py.typed +0 -0
  22. job_shop_lib/reinforcement_learning/__init__.py +4 -0
  23. job_shop_lib/reinforcement_learning/_multi_job_shop_graph_env.py +1 -1
  24. job_shop_lib/reinforcement_learning/_resource_task_graph_observation.py +117 -46
  25. job_shop_lib/reinforcement_learning/_single_job_shop_graph_env.py +11 -2
  26. job_shop_lib/reinforcement_learning/_types_and_constants.py +11 -10
  27. job_shop_lib/reinforcement_learning/_utils.py +29 -0
  28. job_shop_lib/visualization/gantt/_gantt_chart_video_and_gif_creation.py +5 -2
  29. job_shop_lib/visualization/graphs/_plot_disjunctive_graph.py +53 -19
  30. {job_shop_lib-1.0.0b4.dist-info → job_shop_lib-1.0.1.dist-info}/METADATA +4 -10
  31. {job_shop_lib-1.0.0b4.dist-info → job_shop_lib-1.0.1.dist-info}/RECORD +33 -31
  32. {job_shop_lib-1.0.0b4.dist-info → job_shop_lib-1.0.1.dist-info}/LICENSE +0 -0
  33. {job_shop_lib-1.0.0b4.dist-info → job_shop_lib-1.0.1.dist-info}/WHEEL +0 -0
@@ -13,6 +13,8 @@
13
13
  RenderConfig
14
14
  add_padding
15
15
  create_edge_type_dict
16
+ map_values
17
+ get_optimal_actions
16
18
  ResourceTaskGraphObservation
17
19
  ResourceTaskGraphObservationDict
18
20
 
@@ -34,6 +36,7 @@ from job_shop_lib.reinforcement_learning._utils import (
34
36
  add_padding,
35
37
  create_edge_type_dict,
36
38
  map_values,
39
+ get_optimal_actions,
37
40
  )
38
41
 
39
42
  from job_shop_lib.reinforcement_learning._single_job_shop_graph_env import (
@@ -61,4 +64,5 @@ __all__ = [
61
64
  "ResourceTaskGraphObservation",
62
65
  "map_values",
63
66
  "ResourceTaskGraphObservationDict",
67
+ "get_optimal_actions",
64
68
  ]
@@ -117,7 +117,7 @@ class MultiJobShopGraphEnv(gym.Env):
117
117
  graph_initializer:
118
118
  Function to create the initial graph representation.
119
119
  If ``None``, the default graph initializer is used:
120
- :func:`~job_shop_lib.graphs.build_agent_task_graph`.
120
+ :func:`~job_shop_lib.graphs.build_resource_task_graph`.
121
121
  graph_updater_config:
122
122
  Configuration for the graph updater. The graph updater is used
123
123
  to update the graph representation after each action. If
@@ -1,6 +1,6 @@
1
1
  """Contains wrappers for the environments."""
2
2
 
3
- from typing import TypeVar, TypedDict
3
+ from typing import TypeVar, TypedDict, Generic, Any
4
4
  from gymnasium import ObservationWrapper
5
5
  import numpy as np
6
6
  from numpy.typing import NDArray
@@ -12,23 +12,36 @@ from job_shop_lib.reinforcement_learning import (
12
12
  create_edge_type_dict,
13
13
  map_values,
14
14
  )
15
- from job_shop_lib.graphs import NodeType, JobShopGraph
16
- from job_shop_lib.exceptions import ValidationError
15
+ from job_shop_lib.graphs import NodeType
17
16
  from job_shop_lib.dispatching.feature_observers import FeatureType
18
17
 
19
18
  T = TypeVar("T", bound=np.number)
19
+ EnvType = TypeVar( # pylint: disable=invalid-name
20
+ "EnvType", bound=SingleJobShopGraphEnv | MultiJobShopGraphEnv
21
+ )
22
+
23
+ _NODE_TYPE_TO_FEATURE_TYPE = {
24
+ NodeType.OPERATION: FeatureType.OPERATIONS,
25
+ NodeType.MACHINE: FeatureType.MACHINES,
26
+ NodeType.JOB: FeatureType.JOBS,
27
+ }
28
+ _FEATURE_TYPE_STR_TO_NODE_TYPE = {
29
+ FeatureType.OPERATIONS.value: NodeType.OPERATION,
30
+ FeatureType.MACHINES.value: NodeType.MACHINE,
31
+ FeatureType.JOBS.value: NodeType.JOB,
32
+ }
20
33
 
21
34
 
22
35
  class ResourceTaskGraphObservationDict(TypedDict):
23
36
  """Represents a dictionary for resource task graph observations."""
24
37
 
25
- edge_index_dict: dict[str, NDArray[np.int64]]
38
+ edge_index_dict: dict[tuple[str, str, str], NDArray[np.int32]]
26
39
  node_features_dict: dict[str, NDArray[np.float32]]
27
40
  original_ids_dict: dict[str, NDArray[np.int32]]
28
41
 
29
42
 
30
43
  # pylint: disable=line-too-long
31
- class ResourceTaskGraphObservation(ObservationWrapper):
44
+ class ResourceTaskGraphObservation(ObservationWrapper, Generic[EnvType]):
32
45
  """Observation wrapper that converts an observation following the
33
46
  :class:`ObservationDict` format to a format suitable to PyG's
34
47
  [`HeteroData`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.HeteroData.html).
@@ -38,6 +51,12 @@ class ResourceTaskGraphObservation(ObservationWrapper):
38
51
  ``node_type_j`` are the node types of the source and target nodes,
39
52
  respectively.
40
53
 
54
+ Additionally, the node features are stored in a dictionary with keys
55
+ corresponding to the node type names under the ``node_features_dict`` key.
56
+
57
+ The node IDs are mapped to local IDs starting from 0. The
58
+ ``original_ids_dict`` contains the original node IDs before removing nodes.
59
+
41
60
  Attributes:
42
61
  global_to_local_id: A dictionary mapping global node IDs to local node
43
62
  IDs for each node type.
@@ -48,25 +67,12 @@ class ResourceTaskGraphObservation(ObservationWrapper):
48
67
  env: The environment to wrap.
49
68
  """
50
69
 
51
- def __init__(self, env: SingleJobShopGraphEnv | MultiJobShopGraphEnv):
70
+ def __init__(self, env: EnvType):
52
71
  super().__init__(env)
72
+ self.env = env # Unnecessary, but makes mypy happy
53
73
  self.global_to_local_id = self._compute_id_mappings()
54
74
  self.type_ranges = self._compute_node_type_ranges()
55
-
56
- @property
57
- def job_shop_graph(self) -> JobShopGraph:
58
- """Returns the job shop graph from the environment.
59
-
60
- Raises:
61
- ValidationError: If the environment is not an instance of
62
- ``SingleJobShopGraphEnv`` or ``MultiJobShopGraphEnv``.
63
- """
64
- if isinstance(self.env, (SingleJobShopGraphEnv, MultiJobShopGraphEnv)):
65
- return self.env.job_shop_graph
66
- raise ValidationError(
67
- "The environment must be an instance of "
68
- "SingleJobShopGraphEnv or MultiJobShopGraphEnv"
69
- )
75
+ self._start_from_zero_mapping: dict[str, dict[int, int]] = {}
70
76
 
71
77
  def step(self, action: tuple[int, int]):
72
78
  """Takes a step in the environment.
@@ -92,7 +98,9 @@ class ResourceTaskGraphObservation(ObservationWrapper):
92
98
  machine_id, job_id).
93
99
  """
94
100
  observation, reward, done, truncated, info = self.env.step(action)
95
- return self.observation(observation), reward, done, truncated, info
101
+ new_observation = self.observation(observation)
102
+ new_info = self._info(info)
103
+ return new_observation, reward, done, truncated, new_info
96
104
 
97
105
  def reset(self, *, seed: int | None = None, options: dict | None = None):
98
106
  """Resets the environment.
@@ -116,7 +124,34 @@ class ResourceTaskGraphObservation(ObservationWrapper):
116
124
  (operation_id, machine_id, job_id).
117
125
  """
118
126
  observation, info = self.env.reset()
119
- return self.observation(observation), info
127
+ new_observation = self.observation(observation)
128
+ new_info = self._info(info)
129
+ return new_observation, new_info
130
+
131
+ def _info(self, info: dict[str, Any]) -> dict[str, Any]:
132
+ """Updates the "available_operations_with_ids" key in the info
133
+ dictionary so that they start from 0 using the
134
+ `_start_from_zero_mapping` attribute.
135
+ """
136
+ new_available_operations_ids = []
137
+ for operation_id, machine_id, job_id in info[
138
+ "available_operations_with_ids"
139
+ ]:
140
+ if "operation" in self._start_from_zero_mapping:
141
+ operation_id = self._start_from_zero_mapping["operation"][
142
+ operation_id
143
+ ]
144
+ if "machine" in self._start_from_zero_mapping:
145
+ machine_id = self._start_from_zero_mapping["machine"][
146
+ machine_id
147
+ ]
148
+ if "job" in self._start_from_zero_mapping:
149
+ job_id = self._start_from_zero_mapping["job"][job_id]
150
+ new_available_operations_ids.append(
151
+ (operation_id, machine_id, job_id)
152
+ )
153
+ info["available_operations_with_ids"] = new_available_operations_ids
154
+ return info
120
155
 
121
156
  def _compute_id_mappings(self) -> dict[int, int]:
122
157
  """Computes mappings from global node IDs to type-local IDs.
@@ -127,7 +162,7 @@ class ResourceTaskGraphObservation(ObservationWrapper):
127
162
  """
128
163
  mappings = {}
129
164
  for node_type in NodeType:
130
- type_nodes = self.job_shop_graph.nodes_by_type[node_type]
165
+ type_nodes = self.unwrapped.job_shop_graph.nodes_by_type[node_type]
131
166
  if not type_nodes:
132
167
  continue
133
168
  # Create mapping from global ID to local ID
@@ -148,7 +183,7 @@ class ResourceTaskGraphObservation(ObservationWrapper):
148
183
  """
149
184
  type_ranges = {}
150
185
  for node_type in NodeType:
151
- type_nodes = self.job_shop_graph.nodes_by_type[node_type]
186
+ type_nodes = self.unwrapped.job_shop_graph.nodes_by_type[node_type]
152
187
  if not type_nodes:
153
188
  continue
154
189
  start = min(node.node_id for node in type_nodes)
@@ -157,21 +192,50 @@ class ResourceTaskGraphObservation(ObservationWrapper):
157
192
 
158
193
  return type_ranges
159
194
 
160
- def observation(self, observation: ObservationDict):
195
+ def observation(
196
+ self, observation: ObservationDict
197
+ ) -> ResourceTaskGraphObservationDict:
198
+ """Processes the observation data into the resource task graph format.
199
+
200
+ Args:
201
+ observation: The observation dictionary. It must NOT have padding.
202
+
203
+ Returns:
204
+ A dictionary containing the following keys:
205
+
206
+ - "edge_index_dict": A dictionary mapping edge types to edge index
207
+ arrays.
208
+ - "node_features_dict": A dictionary mapping node type names to
209
+ node feature arrays.
210
+ - "original_ids_dict": A dictionary mapping node type names to the
211
+ original node IDs before removing nodes.
212
+ """
161
213
  edge_index_dict = create_edge_type_dict(
162
214
  observation["edge_index"],
163
215
  type_ranges=self.type_ranges,
164
216
  relationship="to",
165
217
  )
218
+ node_features_dict = self._create_node_features_dict(observation)
219
+ node_features_dict, original_ids_dict = self._remove_nodes(
220
+ node_features_dict, observation["removed_nodes"]
221
+ )
222
+
166
223
  # mapping from global node ID to local node ID
167
224
  for key, edge_index in edge_index_dict.items():
168
225
  edge_index_dict[key] = map_values(
169
226
  edge_index, self.global_to_local_id
170
227
  )
171
- node_features_dict = self._create_node_features_dict(observation)
172
- node_features_dict, original_ids_dict = self._remove_nodes(
173
- node_features_dict, observation["removed_nodes"]
228
+ # mapping so that ids start from 0 in edge index
229
+ self._start_from_zero_mapping = self._get_start_from_zero_mappings(
230
+ original_ids_dict
174
231
  )
232
+ for (type_1, to, type_2), edge_index in edge_index_dict.items():
233
+ edge_index_dict[(type_1, to, type_2)][0] = map_values(
234
+ edge_index[0], self._start_from_zero_mapping[type_1]
235
+ )
236
+ edge_index_dict[(type_1, to, type_2)][1] = map_values(
237
+ edge_index[1], self._start_from_zero_mapping[type_2]
238
+ )
175
239
 
176
240
  return {
177
241
  "edge_index_dict": edge_index_dict,
@@ -179,6 +243,15 @@ class ResourceTaskGraphObservation(ObservationWrapper):
179
243
  "original_ids_dict": original_ids_dict,
180
244
  }
181
245
 
246
+ @staticmethod
247
+ def _get_start_from_zero_mappings(
248
+ original_indices_dict: dict[str, NDArray[np.int32]]
249
+ ) -> dict[str, dict[int, int]]:
250
+ mappings = {}
251
+ for key, indices in original_indices_dict.items():
252
+ mappings[key] = {idx: i for i, idx in enumerate(indices)}
253
+ return mappings
254
+
182
255
  def _create_node_features_dict(
183
256
  self, observation: ObservationDict
184
257
  ) -> dict[str, NDArray]:
@@ -190,14 +263,10 @@ class ResourceTaskGraphObservation(ObservationWrapper):
190
263
  Returns:
191
264
  Dictionary mapping node type names to node features.
192
265
  """
193
- node_type_to_feature_type = {
194
- NodeType.OPERATION: FeatureType.OPERATIONS,
195
- NodeType.MACHINE: FeatureType.MACHINES,
196
- NodeType.JOB: FeatureType.JOBS,
197
- }
266
+
198
267
  node_features_dict = {}
199
- for node_type, feature_type in node_type_to_feature_type.items():
200
- if node_type in self.job_shop_graph.nodes_by_type:
268
+ for node_type, feature_type in _NODE_TYPE_TO_FEATURE_TYPE.items():
269
+ if self.unwrapped.job_shop_graph.nodes_by_type[node_type]:
201
270
  node_features_dict[feature_type.value] = observation[
202
271
  feature_type.value
203
272
  ]
@@ -210,7 +279,7 @@ class ResourceTaskGraphObservation(ObservationWrapper):
210
279
  ]
211
280
  job_ids_of_ops = [
212
281
  node.operation.job_id
213
- for node in self.job_shop_graph.nodes_by_type[
282
+ for node in self.unwrapped.job_shop_graph.nodes_by_type[
214
283
  NodeType.OPERATION
215
284
  ]
216
285
  ]
@@ -223,9 +292,9 @@ class ResourceTaskGraphObservation(ObservationWrapper):
223
292
 
224
293
  def _remove_nodes(
225
294
  self,
226
- node_features_dict: dict[str, NDArray[np.float32]],
295
+ node_features_dict: dict[str, NDArray[T]],
227
296
  removed_nodes: NDArray[np.bool_],
228
- ) -> tuple[dict[str, NDArray[np.float32]], dict[str, NDArray[np.int32]]]:
297
+ ) -> tuple[dict[str, NDArray[T]], dict[str, NDArray[np.int32]]]:
229
298
  """Removes nodes from the node features dictionary.
230
299
 
231
300
  Args:
@@ -235,15 +304,12 @@ class ResourceTaskGraphObservation(ObservationWrapper):
235
304
  The node features dictionary with the nodes removed and a
236
305
  dictionary containing the original node ids.
237
306
  """
238
- removed_nodes_dict: dict[str, NDArray[np.float32]] = {}
307
+ removed_nodes_dict: dict[str, NDArray[T]] = {}
239
308
  original_ids_dict: dict[str, NDArray[np.int32]] = {}
240
- feature_type_to_node_type = {
241
- FeatureType.OPERATIONS.value: NodeType.OPERATION,
242
- FeatureType.MACHINES.value: NodeType.MACHINE,
243
- FeatureType.JOBS.value: NodeType.JOB,
244
- }
245
309
  for feature_type, features in node_features_dict.items():
246
- node_type = feature_type_to_node_type[feature_type].name.lower()
310
+ node_type = _FEATURE_TYPE_STR_TO_NODE_TYPE[
311
+ feature_type
312
+ ].name.lower()
247
313
  if node_type not in self.type_ranges:
248
314
  continue
249
315
  start, end = self.type_ranges[node_type]
@@ -256,3 +322,8 @@ class ResourceTaskGraphObservation(ObservationWrapper):
256
322
  )[0]
257
323
 
258
324
  return removed_nodes_dict, original_ids_dict
325
+
326
+ @property
327
+ def unwrapped(self) -> EnvType:
328
+ """Returns the unwrapped environment."""
329
+ return self.env # type: ignore[return-value]
@@ -2,7 +2,7 @@
2
2
 
3
3
  from copy import deepcopy
4
4
  from collections.abc import Callable, Sequence
5
- from typing import Any, Dict, Tuple, List, Optional, Type
5
+ from typing import Any, Dict, Tuple, List, Optional, Type, Union
6
6
 
7
7
  import matplotlib.pyplot as plt
8
8
  import gymnasium as gym
@@ -24,6 +24,8 @@ from job_shop_lib.dispatching import (
24
24
  from job_shop_lib.dispatching.feature_observers import (
25
25
  FeatureObserverConfig,
26
26
  CompositeFeatureObserver,
27
+ FeatureObserver,
28
+ FeatureObserverType,
27
29
  )
28
30
  from job_shop_lib.visualization.gantt import GanttChartCreator
29
31
  from job_shop_lib.reinforcement_learning import (
@@ -137,7 +139,14 @@ class SingleJobShopGraphEnv(gym.Env):
137
139
  def __init__(
138
140
  self,
139
141
  job_shop_graph: JobShopGraph,
140
- feature_observer_configs: Sequence[FeatureObserverConfig],
142
+ feature_observer_configs: Sequence[
143
+ Union[
144
+ str,
145
+ FeatureObserverType,
146
+ Type[FeatureObserver],
147
+ FeatureObserverConfig,
148
+ ],
149
+ ],
141
150
  reward_function_config: DispatcherObserverConfig[
142
151
  Type[RewardObserver]
143
152
  ] = DispatcherObserverConfig(class_type=MakespanReward),
@@ -5,6 +5,7 @@ from enum import Enum
5
5
  from typing import TypedDict
6
6
 
7
7
  import numpy as np
8
+ from numpy.typing import NDArray
8
9
 
9
10
  from job_shop_lib.dispatching.feature_observers import FeatureType
10
11
  from job_shop_lib.visualization.gantt import (
@@ -35,27 +36,27 @@ class ObservationSpaceKey(str, Enum):
35
36
  class _ObservationDictRequired(TypedDict):
36
37
  """Required fields for the observation dictionary."""
37
38
 
38
- removed_nodes: np.ndarray
39
- edge_index: np.ndarray
39
+ removed_nodes: NDArray[np.bool_]
40
+ edge_index: NDArray[np.int32]
40
41
 
41
42
 
42
43
  class _ObservationDictOptional(TypedDict, total=False):
43
44
  """Optional fields for the observation dictionary."""
44
45
 
45
- operations: np.ndarray
46
- jobs: np.ndarray
47
- machines: np.ndarray
46
+ operations: NDArray[np.float32]
47
+ jobs: NDArray[np.float32]
48
+ machines: NDArray[np.float32]
48
49
 
49
50
 
50
51
  class ObservationDict(_ObservationDictRequired, _ObservationDictOptional):
51
52
  """A dictionary containing the observation of the environment.
52
53
 
53
54
  Required fields:
54
- removed_nodes (np.ndarray): Binary vector indicating removed nodes.
55
- edge_index (np.ndarray): Edge list in COO format.
55
+ removed_nodes: Binary vector indicating removed nodes.
56
+ edge_index: Edge list in COO format.
56
57
 
57
58
  Optional fields:
58
- operations (np.ndarray): Matrix of operation features.
59
- jobs (np.ndarray): Matrix of job features.
60
- machines (np.ndarray): Matrix of machine features.
59
+ operations: Matrix of operation features.
60
+ jobs: Matrix of job features.
61
+ machines: Matrix of machine features.
61
62
  """
@@ -6,6 +6,7 @@ import numpy as np
6
6
  from numpy.typing import NDArray
7
7
 
8
8
  from job_shop_lib.exceptions import ValidationError
9
+ from job_shop_lib.dispatching import OptimalOperationsObserver
9
10
 
10
11
  T = TypeVar("T", bound=np.number)
11
12
 
@@ -164,6 +165,34 @@ def map_values(array: NDArray[T], mapping: dict[int, int]) -> NDArray[T]:
164
165
  ) from e
165
166
 
166
167
 
168
+ def get_optimal_actions(
169
+ optimal_ops_observer: OptimalOperationsObserver,
170
+ available_operations_with_ids: list[tuple[int, int, int]],
171
+ ) -> dict[tuple[int, int, int], int]:
172
+ """Indicates if each action is optimal according to a
173
+ :class:`~job_shop_lib.dispatching.OptimalOperationsObserver` instance.
174
+
175
+ Args:
176
+ optimal_ops_observer: The observer that provides optimal operations.
177
+ available_operations_with_ids: List of available operations with their
178
+ IDs (operation_id, machine_id, job_id).
179
+
180
+ Returns:
181
+ A dictionary mapping each tuple
182
+ (operation_id, machine_id, job_id) in the available actions to a binary
183
+ indicator (1 if optimal, 0 otherwise).
184
+ """
185
+ optimal_actions = {}
186
+ optimal_ops = optimal_ops_observer.optimal_available
187
+ optimal_ops_ids = [
188
+ (op.operation_id, op.machine_id, op.job_id) for op in optimal_ops
189
+ ]
190
+ for operation_id, machine_id, job_id in available_operations_with_ids:
191
+ is_optimal = (operation_id, machine_id, job_id) in optimal_ops_ids
192
+ optimal_actions[(operation_id, machine_id, job_id)] = int(is_optimal)
193
+ return optimal_actions
194
+
195
+
167
196
  if __name__ == "__main__":
168
197
  import doctest
169
198
 
@@ -3,7 +3,7 @@
3
3
  import os
4
4
  import pathlib
5
5
  import shutil
6
- from typing import Sequence, Protocol, Optional, List
6
+ from typing import Sequence, Protocol, Optional, List, Any
7
7
 
8
8
  import imageio
9
9
  import matplotlib.pyplot as plt
@@ -68,6 +68,7 @@ def get_partial_gantt_chart_plotter(
68
68
  title: Optional[str] = None,
69
69
  cmap: str = "viridis",
70
70
  show_available_operations: bool = False,
71
+ **kwargs: Any,
71
72
  ) -> PartialGanttChartPlotter:
72
73
  """Returns a function that plots a Gantt chart for an unfinished schedule.
73
74
 
@@ -76,6 +77,8 @@ def get_partial_gantt_chart_plotter(
76
77
  cmap: The name of the colormap to use.
77
78
  show_available_operations:
78
79
  Whether to show the available operations in the Gantt chart.
80
+ **kwargs: Additional keyword arguments to pass to the
81
+ :func:`plot_gantt_chart` function.
79
82
 
80
83
  Returns:
81
84
  A function that plots a Gantt chart for a schedule. The function takes
@@ -97,7 +100,7 @@ def get_partial_gantt_chart_plotter(
97
100
  current_time: Optional[int] = None,
98
101
  ) -> Figure:
99
102
  fig, ax = plot_gantt_chart(
100
- schedule, title=title, cmap_name=cmap, xlim=makespan
103
+ schedule, title=title, cmap_name=cmap, xlim=makespan, **kwargs
101
104
  )
102
105
 
103
106
  if show_available_operations and available_operations is not None:
@@ -7,6 +7,7 @@ import warnings
7
7
  import copy
8
8
 
9
9
  import matplotlib
10
+ import matplotlib.colors
10
11
  import matplotlib.pyplot as plt
11
12
  import networkx as nx
12
13
  from networkx.drawing.nx_agraph import graphviz_layout
@@ -66,6 +67,9 @@ def plot_disjunctive_graph(
66
67
  alpha: float = 0.95,
67
68
  operation_node_labeler: Callable[[Node], str] = duration_labeler,
68
69
  node_font_color: str = "white",
70
+ machine_colors: Optional[
71
+ Dict[int, Tuple[float, float, float, float]]
72
+ ] = None,
69
73
  color_map: str = "Dark2_r",
70
74
  disjunctive_edge_color: str = "red",
71
75
  conjunctive_edge_color: str = "black",
@@ -114,6 +118,12 @@ def plot_disjunctive_graph(
114
118
  with their duration.
115
119
  node_font_color:
116
120
  The color of the node labels (default is ``"white"``).
121
+ machine_colors:
122
+ A dictionary that maps machine ids to colors. If not provided,
123
+ the colors are generated using the ``color_map``. If provided,
124
+ the colors are used as the base for the node colors. The
125
+ dictionary should have the form ``{machine_id: (r, g, b, a)}``.
126
+ For source and sink nodes use ``-1`` as the machine id.
117
127
  color_map:
118
128
  The color map to use for the nodes (default is ``"Dark2_r"``).
119
129
  disjunctive_edge_color:
@@ -229,12 +239,40 @@ def plot_disjunctive_graph(
229
239
 
230
240
  # Draw nodes
231
241
  # ----------
232
- node_colors = [
233
- _get_node_color(node)
234
- for node in job_shop_graph.nodes
235
- if not job_shop_graph.is_removed(node.node_id)
236
- ]
237
- cmap_func = matplotlib.colormaps.get_cmap(color_map)
242
+ operation_nodes = job_shop_graph.nodes_by_type[NodeType.OPERATION]
243
+ cmap_func: Optional[matplotlib.colors.Colormap] = None
244
+ if machine_colors is None:
245
+ machine_colors = {}
246
+ cmap_func = matplotlib.colormaps.get_cmap(color_map)
247
+ remaining_machines = job_shop_graph.instance.num_machines
248
+ for operation_node in operation_nodes:
249
+ if job_shop_graph.is_removed(operation_node.node_id):
250
+ continue
251
+ machine_id = operation_node.operation.machine_id
252
+ if machine_id not in machine_colors:
253
+ machine_colors[machine_id] = cmap_func(
254
+ (_get_node_color(operation_node) + 1)
255
+ / job_shop_graph.instance.num_machines
256
+ )
257
+ remaining_machines -= 1
258
+ if remaining_machines == 0:
259
+ break
260
+ node_colors: list[Any] = [
261
+ _get_node_color(node)
262
+ for node in job_shop_graph.nodes
263
+ if not job_shop_graph.is_removed(node.node_id)
264
+ ]
265
+ else:
266
+ node_colors = []
267
+ for node in job_shop_graph.nodes:
268
+ if job_shop_graph.is_removed(node.node_id):
269
+ continue
270
+ if node.node_type == NodeType.OPERATION:
271
+ machine_id = node.operation.machine_id
272
+ else:
273
+ machine_id = -1
274
+ node_colors.append(machine_colors[machine_id])
275
+
238
276
  nx.draw_networkx_nodes(
239
277
  job_shop_graph.graph,
240
278
  pos,
@@ -292,24 +330,20 @@ def plot_disjunctive_graph(
292
330
 
293
331
  # Draw node labels
294
332
  # ----------------
295
- operation_nodes = job_shop_graph.nodes_by_type[NodeType.OPERATION]
296
333
  labels = {}
297
- source_node = job_shop_graph.nodes_by_type[NodeType.SOURCE][0]
298
- labels[source_node] = start_node_label
299
-
300
- sink_node = job_shop_graph.nodes_by_type[NodeType.SINK][0]
301
- labels[sink_node] = end_node_label
302
- machine_colors: dict[int, Tuple[float, float, float, float]] = {}
334
+ if job_shop_graph.nodes_by_type[NodeType.SOURCE]:
335
+ source_node = job_shop_graph.nodes_by_type[NodeType.SOURCE][0]
336
+ if not job_shop_graph.is_removed(source_node.node_id):
337
+ labels[source_node] = start_node_label
338
+ if job_shop_graph.nodes_by_type[NodeType.SINK]:
339
+ sink_node = job_shop_graph.nodes_by_type[NodeType.SINK][0]
340
+ # check if the sink node is removed
341
+ if not job_shop_graph.is_removed(sink_node.node_id):
342
+ labels[sink_node] = end_node_label
303
343
  for operation_node in operation_nodes:
304
344
  if job_shop_graph.is_removed(operation_node.node_id):
305
345
  continue
306
346
  labels[operation_node] = operation_node_labeler(operation_node)
307
- machine_id = operation_node.operation.machine_id
308
- if machine_id not in machine_colors:
309
- machine_colors[machine_id] = cmap_func(
310
- (_get_node_color(operation_node) + 1)
311
- / job_shop_graph.instance.num_machines
312
- )
313
347
 
314
348
  nx.draw_networkx_labels(
315
349
  job_shop_graph.graph,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: job-shop-lib
3
- Version: 1.0.0b4
3
+ Version: 1.0.1
4
4
  Summary: An easy-to-use and modular Python library for the Job Shop Scheduling Problem (JSSP)
5
5
  License: MIT
6
6
  Author: Pabloo22
@@ -12,7 +12,7 @@ Classifier: Programming Language :: Python :: 3.10
12
12
  Classifier: Programming Language :: Python :: 3.11
13
13
  Classifier: Programming Language :: Python :: 3.12
14
14
  Provides-Extra: pygraphviz
15
- Requires-Dist: gymnasium (>=0.29.1,<0.30.0)
15
+ Requires-Dist: gymnasium (>=1.0.0,<2.0.0)
16
16
  Requires-Dist: imageio[ffmpeg] (>=2.34.1,<3.0.0)
17
17
  Requires-Dist: matplotlib (>=3,<4)
18
18
  Requires-Dist: networkx (>=3,<4)
@@ -48,7 +48,7 @@ See the [documentation](https://job-shop-lib.readthedocs.io/en/latest/) for more
48
48
 
49
49
  JobShopLib is distributed on [PyPI](https://pypi.org/project/job-shop-lib/) and it supports Python 3.10+.
50
50
 
51
- You can install the latest stable version (version 0.5.1) using `pip`:
51
+ You can install the latest stable version using `pip`:
52
52
 
53
53
  ```bash
54
54
  pip install job-shop-lib
@@ -57,13 +57,7 @@ pip install job-shop-lib
57
57
  See [this](https://colab.research.google.com/drive/1XV_Rvq1F2ns6DFG8uNj66q_rcowwTZ4H?usp=sharing) Google Colab notebook for a quick start guide!
58
58
 
59
59
 
60
- Version 1.0.0 is currently in beta stage and can be installed with:
61
-
62
- ```bash
63
- pip install job-shop-lib==1.0.0b4
64
- ```
65
-
66
- Although this version is not stable and may contain breaking changes in subsequent releases, it is recommended to install it to access the new reinforcement learning environments and familiarize yourself with new changes (see the [latest pull requests](https://github.com/Pabloo22/job_shop_lib/pulls?q=is%3Apr+is%3Aclosed)). There is a [documentation page](https://job-shop-lib.readthedocs.io/en/latest/) for versions 1.0.0a3 and onward.
60
+ There is a [documentation page](https://job-shop-lib.readthedocs.io/en/latest/) for versions 1.0.0a3 and onward. See see the [latest pull requests](https://github.com/Pabloo22/job_shop_lib/pulls?q=is%3Apr+is%3Aclosed) for the latest changes.
67
61
 
68
62
  <!-- end installation -->
69
63