manim-0.18.0-py3-none-any.whl → manim-0.18.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of manim has been flagged as potentially problematic; see the registry's advisory page for this release for more details.

Files changed (115)
  1. manim/__init__.py +3 -6
  2. manim/__main__.py +18 -10
  3. manim/_config/__init__.py +5 -2
  4. manim/_config/cli_colors.py +12 -8
  5. manim/_config/default.cfg +1 -1
  6. manim/_config/logger_utils.py +9 -8
  7. manim/_config/utils.py +637 -449
  8. manim/animation/animation.py +9 -2
  9. manim/animation/composition.py +78 -40
  10. manim/animation/creation.py +12 -6
  11. manim/animation/fading.py +0 -1
  12. manim/animation/indication.py +10 -21
  13. manim/animation/movement.py +1 -2
  14. manim/animation/rotation.py +1 -1
  15. manim/animation/specialized.py +1 -1
  16. manim/animation/speedmodifier.py +7 -2
  17. manim/animation/transform_matching_parts.py +1 -1
  18. manim/camera/camera.py +13 -4
  19. manim/cli/cfg/group.py +18 -8
  20. manim/cli/checkhealth/checks.py +2 -0
  21. manim/cli/checkhealth/commands.py +2 -0
  22. manim/cli/default_group.py +13 -5
  23. manim/cli/init/commands.py +4 -1
  24. manim/cli/plugins/commands.py +3 -0
  25. manim/cli/render/commands.py +27 -20
  26. manim/cli/render/ease_of_access_options.py +4 -3
  27. manim/cli/render/global_options.py +9 -7
  28. manim/cli/render/output_options.py +6 -5
  29. manim/cli/render/render_options.py +13 -13
  30. manim/constants.py +54 -15
  31. manim/gui/gui.py +2 -0
  32. manim/mobject/geometry/arc.py +4 -4
  33. manim/mobject/geometry/boolean_ops.py +13 -9
  34. manim/mobject/geometry/line.py +16 -8
  35. manim/mobject/geometry/polygram.py +17 -5
  36. manim/mobject/geometry/tips.py +2 -2
  37. manim/mobject/graph.py +379 -106
  38. manim/mobject/graphing/coordinate_systems.py +17 -20
  39. manim/mobject/graphing/functions.py +14 -10
  40. manim/mobject/graphing/number_line.py +1 -1
  41. manim/mobject/mobject.py +175 -72
  42. manim/mobject/opengl/opengl_compatibility.py +2 -0
  43. manim/mobject/opengl/opengl_geometry.py +26 -1
  44. manim/mobject/opengl/opengl_image_mobject.py +2 -0
  45. manim/mobject/opengl/opengl_mobject.py +3 -0
  46. manim/mobject/opengl/opengl_point_cloud_mobject.py +2 -0
  47. manim/mobject/opengl/opengl_surface.py +2 -0
  48. manim/mobject/opengl/opengl_three_dimensions.py +2 -0
  49. manim/mobject/opengl/opengl_vectorized_mobject.py +19 -14
  50. manim/mobject/svg/brace.py +2 -0
  51. manim/mobject/svg/svg_mobject.py +10 -12
  52. manim/mobject/table.py +0 -1
  53. manim/mobject/text/code_mobject.py +2 -0
  54. manim/mobject/text/numbers.py +2 -0
  55. manim/mobject/text/tex_mobject.py +1 -1
  56. manim/mobject/text/text_mobject.py +43 -6
  57. manim/mobject/three_d/three_d_utils.py +4 -4
  58. manim/mobject/three_d/three_dimensions.py +4 -4
  59. manim/mobject/types/image_mobject.py +5 -1
  60. manim/mobject/types/point_cloud_mobject.py +2 -0
  61. manim/mobject/types/vectorized_mobject.py +124 -29
  62. manim/mobject/value_tracker.py +3 -3
  63. manim/mobject/vector_field.py +3 -1
  64. manim/plugins/__init__.py +15 -1
  65. manim/plugins/plugins_flags.py +11 -5
  66. manim/renderer/cairo_renderer.py +12 -2
  67. manim/renderer/opengl_renderer.py +2 -3
  68. manim/renderer/opengl_renderer_window.py +2 -0
  69. manim/renderer/shader_wrapper.py +2 -0
  70. manim/renderer/vectorized_mobject_rendering.py +5 -0
  71. manim/scene/scene.py +22 -6
  72. manim/scene/scene_file_writer.py +3 -1
  73. manim/scene/section.py +2 -0
  74. manim/scene/three_d_scene.py +5 -6
  75. manim/scene/vector_space_scene.py +21 -5
  76. manim/typing.py +567 -67
  77. manim/utils/bezier.py +9 -18
  78. manim/utils/caching.py +2 -0
  79. manim/utils/color/BS381.py +1 -0
  80. manim/utils/color/XKCD.py +1 -0
  81. manim/utils/color/core.py +31 -13
  82. manim/utils/commands.py +8 -1
  83. manim/utils/debug.py +0 -1
  84. manim/utils/deprecation.py +3 -2
  85. manim/utils/docbuild/__init__.py +17 -0
  86. manim/utils/docbuild/autoaliasattr_directive.py +197 -0
  87. manim/utils/docbuild/autocolor_directive.py +9 -4
  88. manim/utils/docbuild/manim_directive.py +18 -9
  89. manim/utils/docbuild/module_parsing.py +198 -0
  90. manim/utils/exceptions.py +6 -0
  91. manim/utils/family.py +2 -0
  92. manim/utils/family_ops.py +5 -0
  93. manim/utils/file_ops.py +6 -2
  94. manim/utils/hashing.py +2 -0
  95. manim/utils/ipython_magic.py +2 -0
  96. manim/utils/module_ops.py +2 -0
  97. manim/utils/opengl.py +14 -0
  98. manim/utils/parameter_parsing.py +31 -0
  99. manim/utils/paths.py +12 -20
  100. manim/utils/rate_functions.py +6 -8
  101. manim/utils/space_ops.py +81 -36
  102. manim/utils/testing/__init__.py +17 -0
  103. manim/utils/testing/frames_comparison.py +7 -5
  104. manim/utils/tex.py +124 -196
  105. manim/utils/tex_file_writing.py +2 -0
  106. manim/utils/tex_templates.py +1 -0
  107. {manim-0.18.0.dist-info → manim-0.18.1.dist-info}/LICENSE.community +1 -1
  108. {manim-0.18.0.dist-info → manim-0.18.1.dist-info}/METADATA +29 -35
  109. {manim-0.18.0.dist-info → manim-0.18.1.dist-info}/RECORD +112 -112
  110. {manim-0.18.0.dist-info → manim-0.18.1.dist-info}/WHEEL +1 -1
  111. manim/cli/new/__init__.py +0 -0
  112. manim/cli/new/group.py +0 -189
  113. manim/plugins/import_plugins.py +0 -43
  114. {manim-0.18.0.dist-info → manim-0.18.1.dist-info}/LICENSE +0 -0
  115. {manim-0.18.0.dist-info → manim-0.18.1.dist-info}/entry_points.txt +0 -0
manim/mobject/graph.py CHANGED
@@ -9,11 +9,18 @@ __all__ = [
9
9
 
10
10
  import itertools as it
11
11
  from copy import copy
12
- from typing import Hashable, Iterable
12
+ from typing import TYPE_CHECKING, Any, Hashable, Iterable, Literal, Protocol, cast
13
13
 
14
14
  import networkx as nx
15
15
  import numpy as np
16
16
 
17
+ if TYPE_CHECKING:
18
+ from typing_extensions import TypeAlias
19
+
20
+ from manim.typing import Point3D
21
+
22
+ NxGraph: TypeAlias = nx.classes.graph.Graph | nx.classes.digraph.DiGraph
23
+
17
24
  from manim.animation.composition import AnimationGroup
18
25
  from manim.animation.creation import Create, Uncreate
19
26
  from manim.mobject.geometry.arc import Dot, LabeledDot
@@ -26,88 +33,290 @@ from manim.mobject.types.vectorized_mobject import VMobject
26
33
  from manim.utils.color import BLACK
27
34
 
28
35
 
29
- def _determine_graph_layout(
30
- nx_graph: nx.classes.graph.Graph | nx.classes.digraph.DiGraph,
31
- layout: str | dict = "spring",
32
- layout_scale: float = 2,
33
- layout_config: dict | None = None,
34
- partitions: list[list[Hashable]] | None = None,
35
- root_vertex: Hashable | None = None,
36
- ) -> dict:
37
- automatic_layouts = {
38
- "circular": nx.layout.circular_layout,
39
- "kamada_kawai": nx.layout.kamada_kawai_layout,
40
- "planar": nx.layout.planar_layout,
41
- "random": nx.layout.random_layout,
42
- "shell": nx.layout.shell_layout,
43
- "spectral": nx.layout.spectral_layout,
44
- "partite": nx.layout.multipartite_layout,
45
- "tree": _tree_layout,
46
- "spiral": nx.layout.spiral_layout,
47
- "spring": nx.layout.spring_layout,
48
- }
49
-
50
- custom_layouts = ["random", "partite", "tree"]
36
+ class LayoutFunction(Protocol):
37
+ """A protocol for automatic layout functions that compute a layout for a graph to be used in :meth:`~.Graph.change_layout`.
51
38
 
52
- if layout_config is None:
53
- layout_config = {}
39
+ .. note:: The layout function must be a pure function, i.e., it must not modify the graph passed to it.
54
40
 
55
- if isinstance(layout, dict):
56
- return layout
57
- elif layout in automatic_layouts and layout not in custom_layouts:
58
- auto_layout = automatic_layouts[layout](
59
- nx_graph, scale=layout_scale, **layout_config
60
- )
61
- # NetworkX returns a dictionary of 3D points if the dimension
62
- # is specified to be 3. Otherwise, it returns a dictionary of
63
- # 2D points, so adjusting is required.
64
- if layout_config.get("dim") == 3:
65
- return auto_layout
66
- else:
67
- return {k: np.append(v, [0]) for k, v in auto_layout.items()}
68
- elif layout == "tree":
69
- return _tree_layout(
70
- nx_graph, root_vertex=root_vertex, scale=layout_scale, **layout_config
71
- )
72
- elif layout == "partite":
73
- if partitions is None or len(partitions) == 0:
74
- raise ValueError(
75
- "The partite layout requires the 'partitions' parameter to contain the partition of the vertices",
76
- )
77
- partition_count = len(partitions)
78
- for i in range(partition_count):
79
- for v in partitions[i]:
80
- if nx_graph.nodes[v] is None:
81
- raise ValueError(
82
- "The partition must contain arrays of vertices in the graph",
83
- )
84
- nx_graph.nodes[v]["subset"] = i
85
- # Add missing vertices to their own side
86
- for v in nx_graph.nodes:
87
- if "subset" not in nx_graph.nodes[v]:
88
- nx_graph.nodes[v]["subset"] = partition_count
89
-
90
- auto_layout = automatic_layouts["partite"](
91
- nx_graph, scale=layout_scale, **layout_config
92
- )
93
- return {k: np.append(v, [0]) for k, v in auto_layout.items()}
94
- elif layout == "random":
95
- # the random layout places coordinates in [0, 1)
96
- # we need to rescale manually afterwards...
97
- auto_layout = automatic_layouts["random"](nx_graph, **layout_config)
98
- for k, v in auto_layout.items():
99
- auto_layout[k] = 2 * layout_scale * (v - np.array([0.5, 0.5]))
100
- return {k: np.append(v, [0]) for k, v in auto_layout.items()}
101
- else:
41
+ Examples
42
+ --------
43
+
44
+ Here is an example that arranges nodes in an n x m grid in sorted order.
45
+
46
+ .. manim:: CustomLayoutExample
47
+ :save_last_frame:
48
+
49
+ class CustomLayoutExample(Scene):
50
+ def construct(self):
51
+ import numpy as np
52
+ import networkx as nx
53
+
54
+ # create custom layout
55
+ def custom_layout(
56
+ graph: nx.Graph,
57
+ scale: float | tuple[float, float, float] = 2,
58
+ n: int | None = None,
59
+ *args: Any,
60
+ **kwargs: Any,
61
+ ):
62
+ nodes = sorted(list(graph))
63
+ height = len(nodes) // n
64
+ return {
65
+ node: (scale * np.array([
66
+ (i % n) - (n-1)/2,
67
+ -(i // n) + height/2,
68
+ 0
69
+ ])) for i, node in enumerate(graph)
70
+ }
71
+
72
+ # draw graph
73
+ n = 4
74
+ graph = Graph(
75
+ [i for i in range(4 * 2 - 1)],
76
+ [(0, 1), (0, 4), (1, 2), (1, 5), (2, 3), (2, 6), (4, 5), (5, 6)],
77
+ labels=True,
78
+ layout=custom_layout,
79
+ layout_config={'n': n}
80
+ )
81
+ self.add(graph)
82
+
83
+ Several automatic layouts are provided by manim, and can be used by passing their name as the ``layout`` parameter to :meth:`~.Graph.change_layout`.
84
+ Alternatively, a custom layout function can be passed to :meth:`~.Graph.change_layout` as the ``layout`` parameter. Such a function must adhere to the :class:`~.LayoutFunction` protocol.
85
+
86
+ The :class:`~.LayoutFunction` s provided by manim are illustrated below:
87
+
88
+ - Circular Layout: places the vertices on a circle
89
+
90
+ .. manim:: CircularLayout
91
+ :save_last_frame:
92
+
93
+ class CircularLayout(Scene):
94
+ def construct(self):
95
+ graph = Graph(
96
+ [1, 2, 3, 4, 5, 6],
97
+ [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 1), (5, 1), (1, 3), (3, 5)],
98
+ layout="circular",
99
+ labels=True
100
+ )
101
+ self.add(graph)
102
+
103
+ - Kamada Kawai Layout: tries to place the vertices such that the given distances between them are respected
104
+
105
+ .. manim:: KamadaKawaiLayout
106
+ :save_last_frame:
107
+
108
+ class KamadaKawaiLayout(Scene):
109
+ def construct(self):
110
+ from collections import defaultdict
111
+ distances: dict[int, dict[int, float]] = defaultdict(dict)
112
+
113
+ # set desired distances
114
+ distances[1][2] = 1 # distance between vertices 1 and 2 is 1
115
+ distances[2][3] = 1 # distance between vertices 2 and 3 is 1
116
+ distances[3][4] = 2 # etc
117
+ distances[4][5] = 3
118
+ distances[5][6] = 5
119
+ distances[6][1] = 8
120
+
121
+ graph = Graph(
122
+ [1, 2, 3, 4, 5, 6],
123
+ [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 1)],
124
+ layout="kamada_kawai",
125
+ layout_config={"dist": distances},
126
+ layout_scale=4,
127
+ labels=True
128
+ )
129
+ self.add(graph)
130
+
131
+ - Partite Layout: places vertices into distinct partitions
132
+
133
+ .. manim:: PartiteLayout
134
+ :save_last_frame:
135
+
136
+ class PartiteLayout(Scene):
137
+ def construct(self):
138
+ graph = Graph(
139
+ [1, 2, 3, 4, 5, 6],
140
+ [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 1), (5, 1), (1, 3), (3, 5)],
141
+ layout="partite",
142
+ layout_config={"partitions": [[1,2],[3,4],[5,6]]},
143
+ labels=True
144
+ )
145
+ self.add(graph)
146
+
147
+ - Planar Layout: places vertices such that edges do not cross
148
+
149
+ .. manim:: PlanarLayout
150
+ :save_last_frame:
151
+
152
+ class PlanarLayout(Scene):
153
+ def construct(self):
154
+ graph = Graph(
155
+ [1, 2, 3, 4, 5, 6],
156
+ [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 1), (5, 1), (1, 3), (3, 5)],
157
+ layout="planar",
158
+ layout_scale=4,
159
+ labels=True
160
+ )
161
+ self.add(graph)
162
+
163
+ - Random Layout: randomly places vertices
164
+
165
+ .. manim:: RandomLayout
166
+ :save_last_frame:
167
+
168
+ class RandomLayout(Scene):
169
+ def construct(self):
170
+ graph = Graph(
171
+ [1, 2, 3, 4, 5, 6],
172
+ [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 1), (5, 1), (1, 3), (3, 5)],
173
+ layout="random",
174
+ labels=True
175
+ )
176
+ self.add(graph)
177
+
178
+ - Shell Layout: places vertices in concentric circles
179
+
180
+ .. manim:: ShellLayout
181
+ :save_last_frame:
182
+
183
+ class ShellLayout(Scene):
184
+ def construct(self):
185
+ nlist = [[1, 2, 3], [4, 5, 6, 7, 8, 9]]
186
+ graph = Graph(
187
+ [1, 2, 3, 4, 5, 6, 7, 8, 9],
188
+ [(1, 2), (2, 3), (3, 1), (4, 1), (4, 2), (5, 2), (6, 2), (6, 3), (7, 3), (8, 3), (8, 1), (9, 1)],
189
+ layout="shell",
190
+ layout_config={"nlist": nlist},
191
+ labels=True
192
+ )
193
+ self.add(graph)
194
+
195
+ - Spectral Layout: places vertices using the eigenvectors of the graph Laplacian (clusters nodes which are an approximation of the ratio cut)
196
+
197
+ .. manim:: SpectralLayout
198
+ :save_last_frame:
199
+
200
+ class SpectralLayout(Scene):
201
+ def construct(self):
202
+ graph = Graph(
203
+ [1, 2, 3, 4, 5, 6],
204
+ [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 1), (5, 1), (1, 3), (3, 5)],
205
+ layout="spectral",
206
+ labels=True
207
+ )
208
+ self.add(graph)
209
+
210
+ - Sprial Layout: places vertices in a spiraling pattern
211
+
212
+ .. manim:: SpiralLayout
213
+ :save_last_frame:
214
+
215
+ class SpiralLayout(Scene):
216
+ def construct(self):
217
+ graph = Graph(
218
+ [1, 2, 3, 4, 5, 6],
219
+ [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 1), (5, 1), (1, 3), (3, 5)],
220
+ layout="spiral",
221
+ labels=True
222
+ )
223
+ self.add(graph)
224
+
225
+ - Spring Layout: places nodes according to the Fruchterman-Reingold force-directed algorithm (attempts to minimize edge length while maximizing node separation)
226
+
227
+ .. manim:: SpringLayout
228
+ :save_last_frame:
229
+
230
+ class SpringLayout(Scene):
231
+ def construct(self):
232
+ graph = Graph(
233
+ [1, 2, 3, 4, 5, 6],
234
+ [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 1), (5, 1), (1, 3), (3, 5)],
235
+ layout="spring",
236
+ labels=True
237
+ )
238
+ self.add(graph)
239
+
240
+ - Tree Layout: places vertices into a tree with a root node and branches (can only be used with legal trees)
241
+
242
+ .. manim:: TreeLayout
243
+ :save_last_frame:
244
+
245
+ class TreeLayout(Scene):
246
+ def construct(self):
247
+ graph = Graph(
248
+ [1, 2, 3, 4, 5, 6, 7],
249
+ [(1, 2), (1, 3), (2, 4), (2, 5), (3, 6), (3, 7)],
250
+ layout="tree",
251
+ layout_config={"root_vertex": 1},
252
+ labels=True
253
+ )
254
+ self.add(graph)
255
+
256
+ """
257
+
258
+ def __call__(
259
+ self,
260
+ graph: NxGraph,
261
+ scale: float | tuple[float, float, float] = 2,
262
+ *args: Any,
263
+ **kwargs: Any,
264
+ ) -> dict[Hashable, Point3D]:
265
+ """Given a graph and a scale, return a dictionary of coordinates.
266
+
267
+ Parameters
268
+ ----------
269
+ graph : NxGraph
270
+ The underlying NetworkX graph to be laid out. DO NOT MODIFY.
271
+ scale : float | tuple[float, float, float], optional
272
+ Either a single float value, or a tuple of three float values specifying the scale along each axis.
273
+
274
+ Returns
275
+ -------
276
+ dict[Hashable, Point3D]
277
+ A dictionary mapping vertices to their positions.
278
+ """
279
+ ...
280
+
281
+
282
+ def _partite_layout(
283
+ nx_graph: NxGraph,
284
+ scale: float = 2,
285
+ partitions: list[list[Hashable]] | None = None,
286
+ **kwargs: Any,
287
+ ) -> dict[Hashable, Point3D]:
288
+ if partitions is None or len(partitions) == 0:
102
289
  raise ValueError(
103
- f"The layout '{layout}' is neither a recognized automatic layout, "
104
- "nor a vertex placement dictionary.",
290
+ "The partite layout requires partitions parameter to contain the partition of the vertices",
105
291
  )
292
+ partition_count = len(partitions)
293
+ for i in range(partition_count):
294
+ for v in partitions[i]:
295
+ if nx_graph.nodes[v] is None:
296
+ raise ValueError(
297
+ "The partition must contain arrays of vertices in the graph",
298
+ )
299
+ nx_graph.nodes[v]["subset"] = i
300
+ # Add missing vertices to their own side
301
+ for v in nx_graph.nodes:
302
+ if "subset" not in nx_graph.nodes[v]:
303
+ nx_graph.nodes[v]["subset"] = partition_count
304
+
305
+ return nx.layout.multipartite_layout(nx_graph, scale=scale, **kwargs)
306
+
307
+
308
+ def _random_layout(nx_graph: NxGraph, scale: float = 2, **kwargs: Any):
309
+ # the random layout places coordinates in [0, 1)
310
+ # we need to rescale manually afterwards...
311
+ auto_layout = nx.layout.random_layout(nx_graph, **kwargs)
312
+ for k, v in auto_layout.items():
313
+ auto_layout[k] = 2 * scale * (v - np.array([0.5, 0.5]))
314
+ return {k: np.append(v, [0]) for k, v in auto_layout.items()}
106
315
 
107
316
 
108
317
  def _tree_layout(
109
- T: nx.classes.graph.Graph | nx.classes.digraph.DiGraph,
110
- root_vertex: Hashable | None,
318
+ T: NxGraph,
319
+ root_vertex: Hashable | None = None,
111
320
  scale: float | tuple | None = 2,
112
321
  vertex_spacing: tuple | None = None,
113
322
  orientation: str = "down",
@@ -212,6 +421,68 @@ def _tree_layout(
212
421
  return {v: (np.array([x, y, 0]) - center) * sf for v, (x, y) in pos.items()}
213
422
 
214
423
 
424
+ LayoutName = Literal[
425
+ "circular",
426
+ "kamada_kawai",
427
+ "partite",
428
+ "planar",
429
+ "random",
430
+ "shell",
431
+ "spectral",
432
+ "spiral",
433
+ "spring",
434
+ "tree",
435
+ ]
436
+
437
+ _layouts: dict[LayoutName, LayoutFunction] = {
438
+ "circular": cast(LayoutFunction, nx.layout.circular_layout),
439
+ "kamada_kawai": cast(LayoutFunction, nx.layout.kamada_kawai_layout),
440
+ "partite": cast(LayoutFunction, _partite_layout),
441
+ "planar": cast(LayoutFunction, nx.layout.planar_layout),
442
+ "random": cast(LayoutFunction, _random_layout),
443
+ "shell": cast(LayoutFunction, nx.layout.shell_layout),
444
+ "spectral": cast(LayoutFunction, nx.layout.spectral_layout),
445
+ "spiral": cast(LayoutFunction, nx.layout.spiral_layout),
446
+ "spring": cast(LayoutFunction, nx.layout.spring_layout),
447
+ "tree": cast(LayoutFunction, _tree_layout),
448
+ }
449
+
450
+
451
+ def _determine_graph_layout(
452
+ nx_graph: nx.classes.graph.Graph | nx.classes.digraph.DiGraph,
453
+ layout: LayoutName | dict[Hashable, Point3D] | LayoutFunction = "spring",
454
+ layout_scale: float | tuple[float, float, float] = 2,
455
+ layout_config: dict[str, Any] | None = None,
456
+ ) -> dict[Hashable, Point3D]:
457
+ if layout_config is None:
458
+ layout_config = {}
459
+
460
+ if isinstance(layout, dict):
461
+ return layout
462
+ elif layout in _layouts:
463
+ auto_layout = _layouts[layout](nx_graph, scale=layout_scale, **layout_config)
464
+ # NetworkX returns a dictionary of 3D points if the dimension
465
+ # is specified to be 3. Otherwise, it returns a dictionary of
466
+ # 2D points, so adjusting is required.
467
+ if (
468
+ layout_config.get("dim") == 3
469
+ or auto_layout[next(auto_layout.__iter__())].shape[0] == 3
470
+ ):
471
+ return auto_layout
472
+ else:
473
+ return {k: np.append(v, [0]) for k, v in auto_layout.items()}
474
+ else:
475
+ try:
476
+ return cast(LayoutFunction, layout)(
477
+ nx_graph, scale=layout_scale, **layout_config
478
+ )
479
+ except TypeError as e:
480
+ raise ValueError(
481
+ f"The layout '{layout}' is neither a recognized layout, a layout function,"
482
+ "nor a vertex placement dictionary.",
483
+ )
484
+
485
+
215
486
  class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
216
487
  """Abstract base class for graphs (that is, a collection of vertices
217
488
  connected with edges).
@@ -254,14 +525,14 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
254
525
  layout
255
526
  Either one of ``"spring"`` (the default), ``"circular"``, ``"kamada_kawai"``,
256
527
  ``"planar"``, ``"random"``, ``"shell"``, ``"spectral"``, ``"spiral"``, ``"tree"``, and ``"partite"``
257
- for automatic vertex positioning using ``networkx``
528
+ for automatic vertex positioning primarily using ``networkx``
258
529
  (see `their documentation <https://networkx.org/documentation/stable/reference/drawing.html#module-networkx.drawing.layout>`_
259
- for more details), or a dictionary specifying a coordinate (value)
260
- for each vertex (key) for manual positioning.
530
+ for more details), a dictionary specifying a coordinate (value)
531
+ for each vertex (key) for manual positioning, or a .:class:`~.LayoutFunction` with a user-defined automatic layout.
261
532
  layout_config
262
- Only for automatically generated layouts. A dictionary whose entries
263
- are passed as keyword arguments to the automatic layout algorithm
264
- specified via ``layout`` of``networkx``.
533
+ Only for automatic layouts. A dictionary whose entries
534
+ are passed as keyword arguments to the named layout or automatic layout function
535
+ specified via ``layout``.
265
536
  The ``tree`` layout also accepts a special parameter ``vertex_spacing``
266
537
  passed as a keyword argument inside the ``layout_config`` dictionary.
267
538
  Passing a tuple ``(space_x, space_y)`` as this argument overrides
@@ -301,8 +572,8 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
301
572
  edges: list[tuple[Hashable, Hashable]],
302
573
  labels: bool | dict = False,
303
574
  label_fill_color: str = BLACK,
304
- layout: str | dict = "spring",
305
- layout_scale: float | tuple = 2,
575
+ layout: LayoutName | dict[Hashable, Point3D] | LayoutFunction = "spring",
576
+ layout_scale: float | tuple[float, float, float] = 2,
306
577
  layout_config: dict | None = None,
307
578
  vertex_type: type[Mobject] = Dot,
308
579
  vertex_config: dict | None = None,
@@ -319,15 +590,6 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
319
590
  nx_graph.add_edges_from(edges)
320
591
  self._graph = nx_graph
321
592
 
322
- self._layout = _determine_graph_layout(
323
- nx_graph,
324
- layout=layout,
325
- layout_scale=layout_scale,
326
- layout_config=layout_config,
327
- partitions=partitions,
328
- root_vertex=root_vertex,
329
- )
330
-
331
593
  if isinstance(labels, dict):
332
594
  self._labels = labels
333
595
  elif isinstance(labels, bool):
@@ -361,8 +623,14 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
361
623
 
362
624
  self.vertices = {v: vertex_type(**self._vertex_config[v]) for v in vertices}
363
625
  self.vertices.update(vertex_mobjects)
364
- for v in self.vertices:
365
- self[v].move_to(self._layout[v])
626
+
627
+ self.change_layout(
628
+ layout=layout,
629
+ layout_scale=layout_scale,
630
+ layout_config=layout_config,
631
+ partitions=partitions,
632
+ root_vertex=root_vertex,
633
+ )
366
634
 
367
635
  # build edge_config
368
636
  if edge_config is None:
@@ -399,7 +667,7 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
399
667
  self.add_updater(self.update_edges)
400
668
 
401
669
  @staticmethod
402
- def _empty_networkx_graph():
670
+ def _empty_networkx_graph() -> nx.classes.graph.Graph:
403
671
  """Return an empty networkx graph for the given graph type."""
404
672
  raise NotImplementedError("To be implemented in concrete subclasses")
405
673
 
@@ -415,13 +683,13 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
415
683
  def _create_vertex(
416
684
  self,
417
685
  vertex: Hashable,
418
- position: np.ndarray | None = None,
686
+ position: Point3D | None = None,
419
687
  label: bool = False,
420
688
  label_fill_color: str = BLACK,
421
689
  vertex_type: type[Mobject] = Dot,
422
690
  vertex_config: dict | None = None,
423
691
  vertex_mobject: dict | None = None,
424
- ) -> tuple[Hashable, np.ndarray, dict, Mobject]:
692
+ ) -> tuple[Hashable, Point3D, dict, Mobject]:
425
693
  if position is None:
426
694
  position = self.get_center()
427
695
 
@@ -459,7 +727,7 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
459
727
  def _add_created_vertex(
460
728
  self,
461
729
  vertex: Hashable,
462
- position: np.ndarray,
730
+ position: Point3D,
463
731
  vertex_config: dict,
464
732
  vertex_mobject: Mobject,
465
733
  ) -> Mobject:
@@ -485,7 +753,7 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
485
753
  def _add_vertex(
486
754
  self,
487
755
  vertex: Hashable,
488
- position: np.ndarray | None = None,
756
+ position: Point3D | None = None,
489
757
  label: bool = False,
490
758
  label_fill_color: str = BLACK,
491
759
  vertex_type: type[Mobject] = Dot,
@@ -540,7 +808,7 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
540
808
  vertex_type: type[Mobject] = Dot,
541
809
  vertex_config: dict | None = None,
542
810
  vertex_mobjects: dict | None = None,
543
- ) -> Iterable[tuple[Hashable, np.ndarray, dict, Mobject]]:
811
+ ) -> Iterable[tuple[Hashable, Point3D, dict, Mobject]]:
544
812
  if positions is None:
545
813
  positions = {}
546
814
  if vertex_mobjects is None:
@@ -944,9 +1212,9 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
944
1212
 
945
1213
  def change_layout(
946
1214
  self,
947
- layout: str | dict = "spring",
948
- layout_scale: float = 2,
949
- layout_config: dict | None = None,
1215
+ layout: LayoutName | dict[Hashable, Point3D] | LayoutFunction = "spring",
1216
+ layout_scale: float | tuple[float, float, float] = 2,
1217
+ layout_config: dict[str, Any] | None = None,
950
1218
  partitions: list[list[Hashable]] | None = None,
951
1219
  root_vertex: Hashable | None = None,
952
1220
  ) -> Graph:
@@ -970,14 +1238,19 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
970
1238
  self.play(G.animate.change_layout("circular"))
971
1239
  self.wait()
972
1240
  """
1241
+ layout_config = {} if layout_config is None else layout_config
1242
+ if partitions is not None and "partitions" not in layout_config:
1243
+ layout_config["partitions"] = partitions
1244
+ if root_vertex is not None and "root_vertex" not in layout_config:
1245
+ layout_config["root_vertex"] = root_vertex
1246
+
973
1247
  self._layout = _determine_graph_layout(
974
1248
  self._graph,
975
1249
  layout=layout,
976
1250
  layout_scale=layout_scale,
977
1251
  layout_config=layout_config,
978
- partitions=partitions,
979
- root_vertex=root_vertex,
980
1252
  )
1253
+
981
1254
  for v in self.vertices:
982
1255
  self[v].move_to(self._layout[v])
983
1256
  return self