job-shop-lib 0.5.1__py3-none-any.whl → 1.0.0a1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (76) hide show
  1. job_shop_lib/__init__.py +16 -8
  2. job_shop_lib/{base_solver.py → _base_solver.py} +1 -1
  3. job_shop_lib/{job_shop_instance.py → _job_shop_instance.py} +9 -4
  4. job_shop_lib/_operation.py +95 -0
  5. job_shop_lib/{schedule.py → _schedule.py} +73 -54
  6. job_shop_lib/{scheduled_operation.py → _scheduled_operation.py} +13 -37
  7. job_shop_lib/benchmarking/__init__.py +66 -43
  8. job_shop_lib/benchmarking/_load_benchmark.py +88 -0
  9. job_shop_lib/constraint_programming/__init__.py +13 -0
  10. job_shop_lib/{cp_sat/ortools_solver.py → constraint_programming/_ortools_solver.py} +57 -18
  11. job_shop_lib/dispatching/__init__.py +45 -41
  12. job_shop_lib/dispatching/{dispatcher.py → _dispatcher.py} +153 -80
  13. job_shop_lib/dispatching/_dispatcher_observer_config.py +54 -0
  14. job_shop_lib/dispatching/_factories.py +125 -0
  15. job_shop_lib/dispatching/{history_tracker.py → _history_observer.py} +4 -6
  16. job_shop_lib/dispatching/{pruning_functions.py → _ready_operation_filters.py} +6 -35
  17. job_shop_lib/dispatching/_unscheduled_operations_observer.py +69 -0
  18. job_shop_lib/dispatching/feature_observers/__init__.py +16 -10
  19. job_shop_lib/dispatching/feature_observers/{composite_feature_observer.py → _composite_feature_observer.py} +84 -2
  20. job_shop_lib/dispatching/feature_observers/{duration_observer.py → _duration_observer.py} +6 -17
  21. job_shop_lib/dispatching/feature_observers/{earliest_start_time_observer.py → _earliest_start_time_observer.py} +114 -35
  22. job_shop_lib/dispatching/feature_observers/{factory.py → _factory.py} +31 -5
  23. job_shop_lib/dispatching/feature_observers/{feature_observer.py → _feature_observer.py} +59 -16
  24. job_shop_lib/dispatching/feature_observers/_is_completed_observer.py +97 -0
  25. job_shop_lib/dispatching/feature_observers/_is_ready_observer.py +33 -0
  26. job_shop_lib/dispatching/feature_observers/{position_in_job_observer.py → _position_in_job_observer.py} +1 -8
  27. job_shop_lib/dispatching/feature_observers/{remaining_operations_observer.py → _remaining_operations_observer.py} +8 -26
  28. job_shop_lib/dispatching/rules/__init__.py +51 -0
  29. job_shop_lib/dispatching/rules/_dispatching_rule_factory.py +82 -0
  30. job_shop_lib/dispatching/{dispatching_rule_solver.py → rules/_dispatching_rule_solver.py} +44 -15
  31. job_shop_lib/dispatching/{dispatching_rules.py → rules/_dispatching_rules_functions.py} +74 -21
  32. job_shop_lib/dispatching/rules/_machine_chooser_factory.py +69 -0
  33. job_shop_lib/dispatching/rules/_utils.py +127 -0
  34. job_shop_lib/exceptions.py +18 -0
  35. job_shop_lib/generation/__init__.py +2 -2
  36. job_shop_lib/generation/{general_instance_generator.py → _general_instance_generator.py} +26 -7
  37. job_shop_lib/generation/{instance_generator.py → _instance_generator.py} +13 -3
  38. job_shop_lib/graphs/__init__.py +17 -6
  39. job_shop_lib/graphs/{job_shop_graph.py → _job_shop_graph.py} +81 -2
  40. job_shop_lib/graphs/{node.py → _node.py} +18 -12
  41. job_shop_lib/graphs/graph_updaters/__init__.py +13 -0
  42. job_shop_lib/graphs/graph_updaters/_graph_updater.py +59 -0
  43. job_shop_lib/graphs/graph_updaters/_residual_graph_updater.py +154 -0
  44. job_shop_lib/graphs/graph_updaters/_utils.py +25 -0
  45. job_shop_lib/reinforcement_learning/__init__.py +41 -0
  46. job_shop_lib/reinforcement_learning/_multi_job_shop_graph_env.py +366 -0
  47. job_shop_lib/reinforcement_learning/_reward_observers.py +85 -0
  48. job_shop_lib/reinforcement_learning/_single_job_shop_graph_env.py +337 -0
  49. job_shop_lib/reinforcement_learning/_types_and_constants.py +61 -0
  50. job_shop_lib/reinforcement_learning/_utils.py +96 -0
  51. job_shop_lib/visualization/__init__.py +20 -4
  52. job_shop_lib/visualization/{agent_task_graph.py → _agent_task_graph.py} +28 -9
  53. job_shop_lib/visualization/_gantt_chart_creator.py +219 -0
  54. job_shop_lib/visualization/_gantt_chart_video_and_gif_creation.py +388 -0
  55. {job_shop_lib-0.5.1.dist-info → job_shop_lib-1.0.0a1.dist-info}/METADATA +68 -44
  56. job_shop_lib-1.0.0a1.dist-info/RECORD +66 -0
  57. job_shop_lib/benchmarking/load_benchmark.py +0 -142
  58. job_shop_lib/cp_sat/__init__.py +0 -5
  59. job_shop_lib/dispatching/factories.py +0 -206
  60. job_shop_lib/dispatching/feature_observers/is_completed_observer.py +0 -98
  61. job_shop_lib/dispatching/feature_observers/is_ready_observer.py +0 -40
  62. job_shop_lib/generators/__init__.py +0 -8
  63. job_shop_lib/generators/basic_generator.py +0 -200
  64. job_shop_lib/generators/transformations.py +0 -164
  65. job_shop_lib/operation.py +0 -122
  66. job_shop_lib/visualization/create_gif.py +0 -209
  67. job_shop_lib-0.5.1.dist-info/RECORD +0 -52
  68. /job_shop_lib/dispatching/feature_observers/{is_scheduled_observer.py → _is_scheduled_observer.py} +0 -0
  69. /job_shop_lib/generation/{transformations.py → _transformations.py} +0 -0
  70. /job_shop_lib/graphs/{build_agent_task_graph.py → _build_agent_task_graph.py} +0 -0
  71. /job_shop_lib/graphs/{build_disjunctive_graph.py → _build_disjunctive_graph.py} +0 -0
  72. /job_shop_lib/graphs/{constants.py → _constants.py} +0 -0
  73. /job_shop_lib/visualization/{disjunctive_graph.py → _disjunctive_graph.py} +0 -0
  74. /job_shop_lib/visualization/{gantt_chart.py → _gantt_chart.py} +0 -0
  75. {job_shop_lib-0.5.1.dist-info → job_shop_lib-1.0.0a1.dist-info}/LICENSE +0 -0
  76. {job_shop_lib-0.5.1.dist-info → job_shop_lib-1.0.0a1.dist-info}/WHEEL +0 -0
@@ -0,0 +1,337 @@
1
+ """Home of the `SingleJobShopGraphEnv` class."""
2
+
3
+ from copy import deepcopy
4
+ from collections.abc import Callable, Sequence
5
+ from typing import Any
6
+
7
+ import matplotlib.pyplot as plt
8
+ import gymnasium as gym
9
+ import numpy as np
10
+ from numpy.typing import NDArray
11
+
12
+ from job_shop_lib import JobShopInstance, Operation
13
+ from job_shop_lib.graphs import JobShopGraph
14
+ from job_shop_lib.graphs.graph_updaters import (
15
+ GraphUpdater,
16
+ ResidualGraphUpdater,
17
+ )
18
+ from job_shop_lib.dispatching import (
19
+ Dispatcher,
20
+ filter_dominated_operations,
21
+ DispatcherObserverConfig,
22
+ )
23
+ from job_shop_lib.dispatching.feature_observers import (
24
+ FeatureObserverConfig,
25
+ CompositeFeatureObserver,
26
+ )
27
+ from job_shop_lib.visualization import GanttChartCreator
28
+ from job_shop_lib.reinforcement_learning import (
29
+ RewardObserver,
30
+ MakespanReward,
31
+ add_padding,
32
+ RenderConfig,
33
+ ObservationSpaceKey,
34
+ ObservationDict,
35
+ )
36
+
37
+
38
class SingleJobShopGraphEnv(gym.Env):
    """A Gymnasium environment for solving a specific instance of the Job Shop
    Scheduling Problem represented as a graph.

    This environment manages the scheduling process for a single Job Shop
    instance, using a graph representation and various observers to track the
    state and compute rewards.

    Observation Space:
        A dictionary with the following keys:
            - "removed_nodes": Binary vector indicating removed graph nodes.
            - "edge_index": Matrix of graph edges in COO format.
            - Feature matrices: Keys corresponding to the composite observer
              features (e.g., "operations", "jobs", "machines").

    Action Space:
        MultiDiscrete space representing (job_id, machine_id) pairs. The
        job_id takes values in ``[0, num_jobs - 1]`` and the machine_id in
        ``[-1, num_machines - 1]``.

    Render Modes:
        - "human": Displays the current Gantt chart.
        - "save_video": Saves a video of the complete Gantt chart.
        - "save_gif": Saves a GIF of the complete Gantt chart.

    Attributes:
        dispatcher:
            Manages the scheduling process. See
            :class:`~job_shop_lib.dispatching.Dispatcher`.
        composite_observer:
            A
            :class:`~job_shop_lib.dispatching.feature_observers.CompositeFeatureObserver`
            which aggregates features from multiple observers.
        graph_updater:
            Updates the graph representation after each action. See
            :class:`~job_shop_lib.graphs.graph_updaters.GraphUpdater`.
        reward_function:
            Computes rewards for actions taken. See
            :class:`~job_shop_lib.reinforcement_learning.RewardObserver`.
        action_space:
            Defines the action space. The action is a tuple of two integers
            (job_id, machine_id). The machine_id can be -1 if the selected
            operation can only be scheduled in one machine.
        observation_space:
            Defines the observation space. The observation is a dictionary
            with the following keys:
                - "removed_nodes": Binary vector indicating removed graph nodes.
                - "edge_index": Matrix of graph edges in COO format.
                - Feature matrices: Keys corresponding to the composite
                  observer features (e.g., "operations", "jobs", "machines").
        render_mode:
            The mode for rendering the environment ("human", "save_video",
            "save_gif").
        gantt_chart_creator:
            Creates Gantt chart visualizations. See
            :class:`~job_shop_lib.visualization.GanttChartCreator`.
        use_padding:
            Whether to use padding in observations. Padding keeps the edge
            index at the fixed shape declared in the observation space (the
            number of edges the graph had when the environment was created),
            even if the graph later has fewer edges. Padding entries are -1.
    """

    metadata = {"render_modes": ["human", "save_video", "save_gif"]}

    # I think the class is easier to use this way. We could initialize the
    # class from Dispatcher or an already initialized RewardFunction. However,
    # it would be impossible to add good default values.
    # pylint: disable=too-many-arguments
    def __init__(
        self,
        job_shop_graph: JobShopGraph,
        feature_observer_configs: Sequence[FeatureObserverConfig],
        reward_function_config: DispatcherObserverConfig[
            type[RewardObserver]
        ] = DispatcherObserverConfig(class_type=MakespanReward),
        graph_updater_config: DispatcherObserverConfig[
            type[GraphUpdater]
        ] = DispatcherObserverConfig(class_type=ResidualGraphUpdater),
        ready_operations_filter: (
            Callable[[Dispatcher, list[Operation]], list[Operation]] | None
        ) = filter_dominated_operations,
        render_mode: str | None = None,
        render_config: RenderConfig | None = None,
        use_padding: bool = True,
    ) -> None:
        """Initializes the SingleJobShopGraphEnv environment.

        Args:
            job_shop_graph:
                The JobShopGraph instance representing the job shop problem.
            feature_observer_configs:
                A list of FeatureObserverConfig instances for the feature
                observers.
            reward_function_config:
                The configuration for the reward function.
            graph_updater_config:
                The configuration for the graph updater.
            ready_operations_filter:
                The function to use for pruning dominated operations.
            render_mode:
                The mode for rendering the environment ("human", "save_video",
                "save_gif").
            render_config:
                Configuration for rendering (e.g., paths for saving videos
                or GIFs). Passed as keyword arguments to
                ``GanttChartCreator``.
            use_padding:
                Whether to use padding for the edge index.
        """
        super().__init__()
        # Deep copy of the graph as it was when the environment was created.
        # NOTE(review): ``reset`` below does not read this attribute; confirm
        # whether external callers or subclasses rely on it.
        self.initial_job_shop_graph = deepcopy(job_shop_graph)

        self.dispatcher = Dispatcher(
            job_shop_graph.instance,
            ready_operations_filter=ready_operations_filter,
        )

        # Observers added to track the environment state
        self.composite_observer = (
            CompositeFeatureObserver.from_feature_observer_configs(
                self.dispatcher, feature_observer_configs
            )
        )
        self.graph_updater = graph_updater_config.class_type(
            dispatcher=self.dispatcher,
            job_shop_graph=job_shop_graph,
            **graph_updater_config.kwargs,
        )
        self.reward_function = reward_function_config.class_type(
            dispatcher=self.dispatcher, **reward_function_config.kwargs
        )
        # With ``start=-1`` the machine dimension spans ``[-1, nvec - 2]``,
        # so ``num_machines + 1`` values are needed to cover the ``-1``
        # sentinel as well as every machine id ``0..num_machines - 1``.
        self.action_space = gym.spaces.MultiDiscrete(
            [self.instance.num_jobs, self.instance.num_machines + 1],
            start=[0, -1],
        )
        self.observation_space: gym.spaces.Dict = self._get_observation_space()
        self.render_mode = render_mode
        if render_config is None:
            render_config = {}
        self.gantt_chart_creator = GanttChartCreator(
            dispatcher=self.dispatcher, **render_config
        )
        self.use_padding = use_padding

    @property
    def instance(self) -> JobShopInstance:
        """Returns the instance the environment is working on."""
        return self.job_shop_graph.instance

    @property
    def job_shop_graph(self) -> JobShopGraph:
        """Returns the job shop graph held by the graph updater."""
        return self.graph_updater.job_shop_graph

    def _get_observation_space(self) -> gym.spaces.Dict:
        """Returns the observation space dictionary.

        The edge-index space is sized with the number of edges the graph has
        at construction time; node ids range over ``[-1, num_nodes - 1]``,
        where -1 marks padding.
        """
        num_nodes = len(self.job_shop_graph.nodes)
        num_edges = self.job_shop_graph.num_edges
        dict_space: dict[str, gym.Space] = {
            ObservationSpaceKey.REMOVED_NODES.value: gym.spaces.MultiBinary(
                num_nodes
            ),
            ObservationSpaceKey.EDGE_INDEX.value: gym.spaces.MultiDiscrete(
                np.full(
                    (2, num_edges),
                    fill_value=num_nodes + 1,
                    dtype=np.int32,
                ),
                start=np.full(
                    (2, num_edges),
                    fill_value=-1,  # -1 is used for padding
                    dtype=np.int32,
                ),
            ),
        }
        for feature_type, matrix in self.composite_observer.features.items():
            dict_space[feature_type.value] = gym.spaces.Box(
                low=-np.inf, high=np.inf, shape=matrix.shape
            )
        return gym.spaces.Dict(dict_space)

    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObservationDict, dict[str, Any]]:
        """Resets the environment.

        Args:
            seed:
                Forwarded to ``gym.Env.reset``.
            options:
                Forwarded to ``gym.Env.reset``.

        Returns:
            A tuple ``(observation, info)``. ``info`` contains the same keys
            as the info dictionary returned by :meth:`step`, so the initial
            set of available operations is known before the first step.
        """
        super().reset(seed=seed, options=options)
        self.dispatcher.reset()
        obs = self.get_observation()
        return obs, self._get_info()

    def step(
        self, action: tuple[int, int]
    ) -> tuple[ObservationDict, float, bool, bool, dict[str, Any]]:
        """Takes a step in the environment.

        Args:
            action:
                The action to take. The action is a tuple of two integers
                (job_id, machine_id):
                the job ID and the machine ID in which to schedule the
                operation. A machine_id of -1 selects the operation's own
                ``machine_id``.

        Returns:
            A tuple containing the following elements:
                - The observation of the environment.
                - The reward obtained.
                - Whether the environment is done.
                - Whether the episode was truncated (always False).
                - A dictionary with additional information. The dictionary
                  contains the following keys:
                    - "feature_names": The names of the features in the
                      observation.
                    - "available_operations": The operations that are ready
                      to be scheduled.
        """
        job_id, machine_id = action
        operation = self.dispatcher.next_operation(job_id)
        if machine_id == -1:
            machine_id = operation.machine_id

        self.dispatcher.dispatch(operation, machine_id)

        obs = self.get_observation()
        reward = self.reward_function.last_reward
        done = self.dispatcher.schedule.is_complete()
        truncated = False
        return obs, reward, done, truncated, self._get_info()

    def _get_info(self) -> dict[str, Any]:
        """Returns the auxiliary info dictionary shared by reset and step."""
        return {
            "feature_names": self.composite_observer.column_names,
            "available_operations": self.dispatcher.ready_operations(),
        }

    def get_observation(self) -> ObservationDict:
        """Returns the current observation of the environment."""
        observation: ObservationDict = {
            ObservationSpaceKey.REMOVED_NODES.value: np.array(
                self.job_shop_graph.removed_nodes, dtype=bool
            ),
            ObservationSpaceKey.EDGE_INDEX.value: self._get_edge_index(),
        }
        for feature_type, matrix in self.composite_observer.features.items():
            observation[feature_type.value] = matrix
        return observation

    def _get_edge_index(self) -> NDArray[np.int32]:
        """Returns the edge index matrix with shape ``(2, num_edges)``.

        When ``use_padding`` is True, the matrix is padded with -1 up to the
        shape declared in the observation space.
        """
        edges = list(self.job_shop_graph.graph.edges())
        # ``reshape(-1, 2)`` keeps a ``(2, 0)`` result for an edgeless graph
        # instead of collapsing to ``(0,)``.
        edge_index = np.array(edges, dtype=np.int32).reshape(-1, 2).T

        if self.use_padding:
            output_shape = self.observation_space[
                ObservationSpaceKey.EDGE_INDEX.value
            ].shape
            assert output_shape is not None  # For the type checker
            edge_index = add_padding(
                edge_index, output_shape=output_shape, dtype=np.int32
            )
        return edge_index

    def render(self) -> None:
        """Renders the environment.

        The rendering mode is set by the `render_mode` attribute:

        - human: Renders the current Gantt chart.
        - save_video: Saves a video of the Gantt chart. Used only if the
          schedule is completed.
        - save_gif: Saves a GIF of the Gantt chart. Used only if the schedule
          is completed.

        Any other value (including ``None``) renders nothing.
        """
        if self.render_mode == "human":
            self.gantt_chart_creator.plot_gantt_chart()
            plt.show(block=False)
        elif self.render_mode == "save_video":
            self.gantt_chart_creator.create_video()
        elif self.render_mode == "save_gif":
            self.gantt_chart_creator.create_gif()
313
+
314
+
315
if __name__ == "__main__":
    # Demo: build the environment for the ft06 benchmark with a single
    # "is ready" feature observer computed over the jobs. The environment is
    # only constructed here; it is neither stepped nor rendered.
    from job_shop_lib.dispatching.feature_observers import (
        FeatureObserverType,
        FeatureType,
    )
    from job_shop_lib.graphs import build_disjunctive_graph
    from job_shop_lib.benchmarking import load_benchmark_instance

    ft06_instance = load_benchmark_instance("ft06")
    ft06_graph = build_disjunctive_graph(ft06_instance)
    is_ready_config = DispatcherObserverConfig(
        FeatureObserverType.IS_READY,
        kwargs={"feature_types": [FeatureType.JOBS]},
    )

    env = SingleJobShopGraphEnv(
        job_shop_graph=ft06_graph,
        feature_observer_configs=[is_ready_config],
        render_mode="save_video",
        render_config={"video_config": {"fps": 4}},
    )
@@ -0,0 +1,61 @@
1
+ """Contains types and enumerations used in the reinforcement learning
2
+ module."""
3
+
4
+ from enum import Enum
5
+ from typing import TypedDict
6
+
7
+ import numpy as np
8
+
9
+ from job_shop_lib.dispatching.feature_observers import FeatureType
10
+ from job_shop_lib.visualization import (
11
+ GanttChartWrapperConfig,
12
+ GifConfig,
13
+ VideoConfig,
14
+ )
15
+
16
+
17
class RenderConfig(TypedDict, total=False):
    """Keyword arguments used to build a ``GanttChartCreator``.

    Every key is optional (``total=False``); any subset may be supplied.
    """

    # Forwarded as the ``gantt_chart_wrapper_config`` argument.
    gantt_chart_wrapper_config: GanttChartWrapperConfig
    # Forwarded as the ``video_config`` argument.
    video_config: VideoConfig
    # Forwarded as the ``gif_config`` argument.
    gif_config: GifConfig
23
+
24
+
25
class ObservationSpaceKey(str, Enum):
    """Keys used in the observation-space dictionary.

    Members are ``str`` subclasses, so they compare equal to their string
    values. The feature-matrix keys reuse the values of ``FeatureType`` so the
    two enumerations cannot drift apart.
    """

    # Graph-structure entries.
    REMOVED_NODES = "removed_nodes"
    EDGE_INDEX = "edge_index"

    # Feature-matrix entries mirrored from ``FeatureType``.
    OPERATIONS = FeatureType.OPERATIONS.value
    JOBS = FeatureType.JOBS.value
    MACHINES = FeatureType.MACHINES.value
33
+
34
+
35
class _ObservationDictRequired(TypedDict):
    """Keys that every observation dictionary must contain."""

    # Binary vector flagging which graph nodes have been removed.
    removed_nodes: np.ndarray
    # Graph edges in COO format.
    edge_index: np.ndarray
40
+
41
+
42
class _ObservationDictOptional(TypedDict, total=False):
    """Feature-matrix keys that an observation dictionary may contain."""

    # Matrix of operation features.
    operations: np.ndarray
    # Matrix of job features.
    jobs: np.ndarray
    # Matrix of machine features.
    machines: np.ndarray
48
+
49
+
50
class ObservationDict(_ObservationDictRequired, _ObservationDictOptional):
    """Observation returned by the job shop graph environments.

    Combines two TypedDicts so that some keys are mandatory and others are
    not.

    Always present:
        removed_nodes (np.ndarray): Binary vector indicating removed nodes.
        edge_index (np.ndarray): Edge list in COO format.

    Present only when the corresponding features are observed:
        operations (np.ndarray): Matrix of operation features.
        jobs (np.ndarray): Matrix of job features.
        machines (np.ndarray): Matrix of machine features.
    """
@@ -0,0 +1,96 @@
1
+ """Utility functions for reinforcement learning."""
2
+
3
+ from typing import TypeVar, Any
4
+
5
+ import numpy as np
6
+ from numpy.typing import NDArray
7
+
8
+ from job_shop_lib.exceptions import ValidationError
9
+
10
+ T = TypeVar("T", bound=np.number)
11
+
12
+
13
def add_padding(
    array: NDArray[Any],
    output_shape: tuple[int, ...],
    padding_value: float = -1,
    dtype: type[T] | None = None,
) -> NDArray[T]:
    """Adds padding to the array.

    Pads the input array to the specified output shape with a given padding
    value. The original values are placed starting at index 0 along every
    axis; the remaining positions hold ``padding_value``. If the ``dtype`` is
    not specified, the ``dtype`` of the input array is used.

    Args:
        array:
            The input array to be padded.
        output_shape:
            The desired shape of the output array. Unless ``array`` is empty,
            it must have the same number of dimensions as ``array`` and be at
            least as large as ``array.shape`` along every axis.
        padding_value:
            The value to use for padding. Defaults to -1.
        dtype:
            The data type for the output array. Defaults to ``None``, in which
            case the dtype of the input array is used.

    Returns:
        The padded array with the specified output shape.

    Raises:
        ValidationError:
            If a non-empty ``array`` has a different number of dimensions
            than ``output_shape``, or if the output shape is smaller than the
            input shape along any axis.

    Examples:

        .. doctest::

            >>> array = np.array([[1, 2], [3, 4]])
            >>> add_padding(array, (3, 3))
            array([[ 1,  2, -1],
                   [ 3,  4, -1],
                   [-1, -1, -1]])

            >>> add_padding(array, (3, 3), padding_value=0)
            array([[1, 2, 0],
                   [3, 4, 0],
                   [0, 0, 0]])

            >>> bool_array = np.array([[True, False], [False, True]])
            >>> add_padding(bool_array, (3, 3), padding_value=False, dtype=int)
            array([[1, 0, 0],
                   [0, 1, 0],
                   [0, 0, 0]])

            >>> add_padding(bool_array, (3, 3), dtype=int)
            array([[ 1,  0, -1],
                   [ 0,  1, -1],
                   [-1, -1, -1]])
    """
    # Normalize to a tuple of Python ints (also accepts a bare int, as
    # ``np.full`` does).
    target_shape = tuple(np.atleast_1d(output_shape).tolist())

    if array.ndim == len(target_shape):
        if any(out < inp for out, inp in zip(target_shape, array.shape)):
            raise ValidationError(
                "Output shape must be greater than or equal to the input "
                f"shape. Got output shape: {output_shape}, input shape: "
                f"{array.shape}."
            )
    elif array.size != 0:
        # Without this check NumPy would either silently broadcast the input
        # (e.g. a (1,) array into every row of a (2, 4) output) or fail with
        # an unrelated indexing error.
        raise ValidationError(
            "Output shape must have the same number of dimensions as the "
            f"input shape. Got output shape: {output_shape}, input shape: "
            f"{array.shape}."
        )

    if dtype is None:
        dtype = array.dtype.type

    padded_array = np.full(target_shape, fill_value=padding_value, dtype=dtype)

    if array.size == 0:
        # Nothing to copy. This also covers inputs such as an empty edge list
        # of shape (0,) that cannot be assigned into a 2-D output.
        return padded_array

    padded_array[tuple(slice(0, dim) for dim in array.shape)] = array
    return padded_array
91
+
92
+
93
if __name__ == "__main__":
    # Run the examples embedded in this module's docstrings.
    from doctest import testmod

    testmod()
@@ -1,25 +1,41 @@
1
1
  """Package for visualization."""
2
2
 
3
- from job_shop_lib.visualization.gantt_chart import plot_gantt_chart
4
- from job_shop_lib.visualization.create_gif import (
3
+ from job_shop_lib.visualization._gantt_chart import plot_gantt_chart
4
+ from job_shop_lib.visualization._gantt_chart_video_and_gif_creation import (
5
5
  create_gif,
6
+ create_gantt_chart_video,
6
7
  create_gantt_chart_frames,
7
8
  plot_gantt_chart_wrapper,
9
+ create_video_from_frames,
8
10
  create_gif_from_frames,
9
11
  )
10
- from job_shop_lib.visualization.disjunctive_graph import plot_disjunctive_graph
11
- from job_shop_lib.visualization.agent_task_graph import (
12
+ from job_shop_lib.visualization._disjunctive_graph import (
13
+ plot_disjunctive_graph,
14
+ )
15
+ from job_shop_lib.visualization._agent_task_graph import (
12
16
  plot_agent_task_graph,
13
17
  three_columns_layout,
14
18
  )
19
+ from job_shop_lib.visualization._gantt_chart_creator import (
20
+ GanttChartCreator,
21
+ GanttChartWrapperConfig,
22
+ GifConfig,
23
+ VideoConfig,
24
+ )
15
25
 
16
26
  __all__ = [
17
27
  "plot_gantt_chart",
28
+ "create_gantt_chart_video",
18
29
  "create_gif",
19
30
  "create_gantt_chart_frames",
20
31
  "plot_gantt_chart_wrapper",
21
32
  "create_gif_from_frames",
33
+ "create_video_from_frames",
22
34
  "plot_disjunctive_graph",
23
35
  "plot_agent_task_graph",
24
36
  "three_columns_layout",
37
+ "GanttChartCreator",
38
+ "GanttChartWrapperConfig",
39
+ "GifConfig",
40
+ "VideoConfig",
25
41
  ]
@@ -51,6 +51,7 @@ def plot_agent_task_graph(
51
51
 
52
52
  # Create the networkx graph
53
53
  graph = job_shop_graph.graph
54
+ nodes = job_shop_graph.non_removed_nodes()
54
55
 
55
56
  # Create the layout if it was not provided
56
57
  if layout is None:
@@ -64,16 +65,17 @@ def plot_agent_task_graph(
64
65
  job_shop_graph.nodes_by_type[NodeType.MACHINE]
65
66
  )
66
67
  }
68
+
67
69
  node_colors = [
68
70
  _get_node_color(node, machine_colors) for node in job_shop_graph.nodes
69
- ]
71
+ ] # We need to get the color of all nodes to avoid an index error
70
72
  node_shapes = {"machine": "s", "job": "d", "operation": "o", "global": "o"}
71
73
 
72
74
  # Draw nodes with different shapes based on their type
73
75
  for node_type, shape in node_shapes.items():
74
76
  current_nodes = [
75
77
  node.node_id
76
- for node in job_shop_graph.nodes
78
+ for node in nodes
77
79
  if node.node_type.name.lower() == node_type
78
80
  ]
79
81
  nx.draw_networkx_nodes(
@@ -90,9 +92,7 @@ def plot_agent_task_graph(
90
92
  # Draw edges
91
93
  nx.draw_networkx_edges(graph, layout, ax=ax)
92
94
 
93
- node_labels = {
94
- node.node_id: _get_node_label(node) for node in job_shop_graph.nodes
95
- }
95
+ node_labels = {node.node_id: _get_node_label(node) for node in nodes}
96
96
  nx.draw_networkx_labels(graph, layout, node_labels, ax=ax)
97
97
 
98
98
  ax.set_axis_off()
@@ -181,10 +181,29 @@ def three_columns_layout(
181
181
 
182
182
  x_positions = _get_x_positions(leftmost_position, rightmost_position)
183
183
 
184
- operation_nodes = job_shop_graph.nodes_by_type[NodeType.OPERATION]
185
- machine_nodes = job_shop_graph.nodes_by_type[NodeType.MACHINE]
186
- job_nodes = job_shop_graph.nodes_by_type[NodeType.JOB]
187
- global_nodes = job_shop_graph.nodes_by_type[NodeType.GLOBAL]
184
+ operation_nodes = [
185
+ node
186
+ for node in job_shop_graph.nodes_by_type[NodeType.OPERATION]
187
+ if not job_shop_graph.is_removed(node)
188
+ ]
189
+ machine_nodes = [
190
+ node
191
+ for node in job_shop_graph.nodes_by_type[NodeType.MACHINE]
192
+ if not job_shop_graph.is_removed(node)
193
+ ]
194
+ job_nodes = [
195
+ node
196
+ for node in job_shop_graph.nodes_by_type[NodeType.JOB]
197
+ if not job_shop_graph.is_removed(node)
198
+ ]
199
+ global_nodes = [
200
+ node
201
+ for node in job_shop_graph.nodes_by_type[NodeType.GLOBAL]
202
+ if not job_shop_graph.is_removed(node)
203
+ ]
204
+
205
+ # job_nodes = job_shop_graph.nodes_by_type[NodeType.JOB]
206
+ # global_nodes = job_shop_graph.nodes_by_type[NodeType.GLOBAL]
188
207
 
189
208
  total_positions = len(operation_nodes) + len(global_nodes) * 2
190
209
  y_spacing = (topmost_position - bottommost_position) / total_positions