manim 0.18.0.post0__py3-none-any.whl → 0.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of manim might be problematic; see the package's page on the registry for more details.

Files changed (146)
  1. manim/__init__.py +3 -6
  2. manim/__main__.py +61 -20
  3. manim/_config/__init__.py +6 -3
  4. manim/_config/cli_colors.py +16 -8
  5. manim/_config/default.cfg +1 -3
  6. manim/_config/logger_utils.py +14 -8
  7. manim/_config/utils.py +651 -472
  8. manim/animation/animation.py +152 -5
  9. manim/animation/composition.py +80 -39
  10. manim/animation/creation.py +196 -14
  11. manim/animation/fading.py +5 -9
  12. manim/animation/indication.py +103 -47
  13. manim/animation/movement.py +22 -5
  14. manim/animation/rotation.py +3 -2
  15. manim/animation/specialized.py +4 -6
  16. manim/animation/speedmodifier.py +10 -5
  17. manim/animation/transform.py +4 -5
  18. manim/animation/transform_matching_parts.py +1 -1
  19. manim/animation/updaters/mobject_update_utils.py +17 -14
  20. manim/camera/camera.py +15 -6
  21. manim/cli/__init__.py +17 -0
  22. manim/cli/cfg/group.py +70 -44
  23. manim/cli/checkhealth/checks.py +93 -75
  24. manim/cli/checkhealth/commands.py +14 -5
  25. manim/cli/default_group.py +157 -25
  26. manim/cli/init/commands.py +32 -24
  27. manim/cli/plugins/commands.py +16 -3
  28. manim/cli/render/commands.py +72 -60
  29. manim/cli/render/ease_of_access_options.py +4 -3
  30. manim/cli/render/global_options.py +51 -15
  31. manim/cli/render/output_options.py +6 -5
  32. manim/cli/render/render_options.py +97 -32
  33. manim/constants.py +65 -19
  34. manim/gui/gui.py +2 -0
  35. manim/mobject/frame.py +0 -1
  36. manim/mobject/geometry/arc.py +112 -78
  37. manim/mobject/geometry/boolean_ops.py +32 -25
  38. manim/mobject/geometry/labeled.py +300 -77
  39. manim/mobject/geometry/line.py +132 -64
  40. manim/mobject/geometry/polygram.py +126 -30
  41. manim/mobject/geometry/shape_matchers.py +35 -15
  42. manim/mobject/geometry/tips.py +38 -29
  43. manim/mobject/graph.py +414 -133
  44. manim/mobject/graphing/coordinate_systems.py +126 -64
  45. manim/mobject/graphing/functions.py +25 -15
  46. manim/mobject/graphing/number_line.py +24 -10
  47. manim/mobject/graphing/probability.py +2 -10
  48. manim/mobject/graphing/scale.py +6 -5
  49. manim/mobject/matrix.py +17 -19
  50. manim/mobject/mobject.py +314 -165
  51. manim/mobject/opengl/opengl_compatibility.py +2 -0
  52. manim/mobject/opengl/opengl_geometry.py +30 -9
  53. manim/mobject/opengl/opengl_image_mobject.py +2 -0
  54. manim/mobject/opengl/opengl_mobject.py +509 -343
  55. manim/mobject/opengl/opengl_point_cloud_mobject.py +5 -7
  56. manim/mobject/opengl/opengl_surface.py +3 -2
  57. manim/mobject/opengl/opengl_three_dimensions.py +2 -0
  58. manim/mobject/opengl/opengl_vectorized_mobject.py +46 -79
  59. manim/mobject/svg/brace.py +63 -13
  60. manim/mobject/svg/svg_mobject.py +4 -3
  61. manim/mobject/table.py +11 -13
  62. manim/mobject/text/code_mobject.py +186 -548
  63. manim/mobject/text/numbers.py +9 -7
  64. manim/mobject/text/tex_mobject.py +23 -14
  65. manim/mobject/text/text_mobject.py +70 -24
  66. manim/mobject/three_d/polyhedra.py +98 -1
  67. manim/mobject/three_d/three_d_utils.py +4 -4
  68. manim/mobject/three_d/three_dimensions.py +62 -34
  69. manim/mobject/types/image_mobject.py +42 -24
  70. manim/mobject/types/point_cloud_mobject.py +105 -67
  71. manim/mobject/types/vectorized_mobject.py +496 -228
  72. manim/mobject/value_tracker.py +5 -4
  73. manim/mobject/vector_field.py +5 -5
  74. manim/opengl/__init__.py +3 -3
  75. manim/plugins/__init__.py +14 -1
  76. manim/plugins/plugins_flags.py +14 -8
  77. manim/renderer/cairo_renderer.py +20 -10
  78. manim/renderer/opengl_renderer.py +21 -23
  79. manim/renderer/opengl_renderer_window.py +2 -0
  80. manim/renderer/shader.py +2 -3
  81. manim/renderer/shader_wrapper.py +5 -2
  82. manim/renderer/vectorized_mobject_rendering.py +5 -0
  83. manim/scene/moving_camera_scene.py +23 -0
  84. manim/scene/scene.py +90 -43
  85. manim/scene/scene_file_writer.py +316 -165
  86. manim/scene/section.py +17 -15
  87. manim/scene/three_d_scene.py +13 -21
  88. manim/scene/vector_space_scene.py +22 -9
  89. manim/typing.py +830 -70
  90. manim/utils/bezier.py +1667 -399
  91. manim/utils/caching.py +13 -5
  92. manim/utils/color/AS2700.py +2 -0
  93. manim/utils/color/BS381.py +3 -0
  94. manim/utils/color/DVIPSNAMES.py +96 -0
  95. manim/utils/color/SVGNAMES.py +179 -0
  96. manim/utils/color/X11.py +3 -0
  97. manim/utils/color/XKCD.py +3 -0
  98. manim/utils/color/__init__.py +8 -5
  99. manim/utils/color/core.py +844 -309
  100. manim/utils/color/manim_colors.py +7 -9
  101. manim/utils/commands.py +48 -20
  102. manim/utils/config_ops.py +18 -13
  103. manim/utils/debug.py +8 -7
  104. manim/utils/deprecation.py +90 -40
  105. manim/utils/docbuild/__init__.py +17 -0
  106. manim/utils/docbuild/autoaliasattr_directive.py +234 -0
  107. manim/utils/docbuild/autocolor_directive.py +21 -17
  108. manim/utils/docbuild/manim_directive.py +50 -35
  109. manim/utils/docbuild/module_parsing.py +245 -0
  110. manim/utils/exceptions.py +6 -0
  111. manim/utils/family.py +5 -3
  112. manim/utils/family_ops.py +17 -4
  113. manim/utils/file_ops.py +26 -16
  114. manim/utils/hashing.py +9 -7
  115. manim/utils/images.py +10 -4
  116. manim/utils/ipython_magic.py +14 -8
  117. manim/utils/iterables.py +161 -119
  118. manim/utils/module_ops.py +57 -19
  119. manim/utils/opengl.py +83 -24
  120. manim/utils/parameter_parsing.py +32 -0
  121. manim/utils/paths.py +21 -23
  122. manim/utils/polylabel.py +168 -0
  123. manim/utils/qhull.py +218 -0
  124. manim/utils/rate_functions.py +74 -39
  125. manim/utils/simple_functions.py +24 -15
  126. manim/utils/sounds.py +7 -1
  127. manim/utils/space_ops.py +125 -69
  128. manim/utils/testing/__init__.py +17 -0
  129. manim/utils/testing/_frames_testers.py +13 -8
  130. manim/utils/testing/_show_diff.py +5 -3
  131. manim/utils/testing/_test_class_makers.py +33 -18
  132. manim/utils/testing/frames_comparison.py +27 -19
  133. manim/utils/tex.py +127 -197
  134. manim/utils/tex_file_writing.py +47 -45
  135. manim/utils/tex_templates.py +2 -1
  136. manim/utils/unit.py +6 -5
  137. {manim-0.18.0.post0.dist-info → manim-0.19.0.dist-info}/LICENSE.community +1 -1
  138. {manim-0.18.0.post0.dist-info → manim-0.19.0.dist-info}/METADATA +40 -39
  139. manim-0.19.0.dist-info/RECORD +221 -0
  140. {manim-0.18.0.post0.dist-info → manim-0.19.0.dist-info}/WHEEL +1 -1
  141. manim/cli/new/__init__.py +0 -0
  142. manim/cli/new/group.py +0 -189
  143. manim/plugins/import_plugins.py +0 -43
  144. manim-0.18.0.post0.dist-info/RECORD +0 -217
  145. {manim-0.18.0.post0.dist-info → manim-0.19.0.dist-info}/LICENSE +0 -0
  146. {manim-0.18.0.post0.dist-info → manim-0.19.0.dist-info}/entry_points.txt +0 -0
manim/mobject/graph.py CHANGED
@@ -8,12 +8,21 @@ __all__ = [
8
8
  ]
9
9
 
10
10
  import itertools as it
11
+ from collections.abc import Hashable, Iterable, Sequence
11
12
  from copy import copy
12
- from typing import Hashable, Iterable
13
+ from typing import TYPE_CHECKING, Any, Literal, Protocol, cast
13
14
 
14
15
  import networkx as nx
15
16
  import numpy as np
16
17
 
18
+ if TYPE_CHECKING:
19
+ from typing_extensions import TypeAlias
20
+
21
+ from manim.scene.scene import Scene
22
+ from manim.typing import Point3D, Point3DLike
23
+
24
+ NxGraph: TypeAlias = nx.classes.graph.Graph | nx.classes.digraph.DiGraph
25
+
17
26
  from manim.animation.composition import AnimationGroup
18
27
  from manim.animation.creation import Create, Uncreate
19
28
  from manim.mobject.geometry.arc import Dot, LabeledDot
@@ -26,88 +35,290 @@ from manim.mobject.types.vectorized_mobject import VMobject
26
35
  from manim.utils.color import BLACK
27
36
 
28
37
 
29
- def _determine_graph_layout(
30
- nx_graph: nx.classes.graph.Graph | nx.classes.digraph.DiGraph,
31
- layout: str | dict = "spring",
32
- layout_scale: float = 2,
33
- layout_config: dict | None = None,
34
- partitions: list[list[Hashable]] | None = None,
35
- root_vertex: Hashable | None = None,
36
- ) -> dict:
37
- automatic_layouts = {
38
- "circular": nx.layout.circular_layout,
39
- "kamada_kawai": nx.layout.kamada_kawai_layout,
40
- "planar": nx.layout.planar_layout,
41
- "random": nx.layout.random_layout,
42
- "shell": nx.layout.shell_layout,
43
- "spectral": nx.layout.spectral_layout,
44
- "partite": nx.layout.multipartite_layout,
45
- "tree": _tree_layout,
46
- "spiral": nx.layout.spiral_layout,
47
- "spring": nx.layout.spring_layout,
48
- }
49
-
50
- custom_layouts = ["random", "partite", "tree"]
38
+ class LayoutFunction(Protocol):
39
+ """A protocol for automatic layout functions that compute a layout for a graph to be used in :meth:`~.Graph.change_layout`.
51
40
 
52
- if layout_config is None:
53
- layout_config = {}
41
+ .. note:: The layout function must be a pure function, i.e., it must not modify the graph passed to it.
54
42
 
55
- if isinstance(layout, dict):
56
- return layout
57
- elif layout in automatic_layouts and layout not in custom_layouts:
58
- auto_layout = automatic_layouts[layout](
59
- nx_graph, scale=layout_scale, **layout_config
60
- )
61
- # NetworkX returns a dictionary of 3D points if the dimension
62
- # is specified to be 3. Otherwise, it returns a dictionary of
63
- # 2D points, so adjusting is required.
64
- if layout_config.get("dim") == 3:
65
- return auto_layout
66
- else:
67
- return {k: np.append(v, [0]) for k, v in auto_layout.items()}
68
- elif layout == "tree":
69
- return _tree_layout(
70
- nx_graph, root_vertex=root_vertex, scale=layout_scale, **layout_config
71
- )
72
- elif layout == "partite":
73
- if partitions is None or len(partitions) == 0:
74
- raise ValueError(
75
- "The partite layout requires the 'partitions' parameter to contain the partition of the vertices",
76
- )
77
- partition_count = len(partitions)
78
- for i in range(partition_count):
79
- for v in partitions[i]:
80
- if nx_graph.nodes[v] is None:
81
- raise ValueError(
82
- "The partition must contain arrays of vertices in the graph",
83
- )
84
- nx_graph.nodes[v]["subset"] = i
85
- # Add missing vertices to their own side
86
- for v in nx_graph.nodes:
87
- if "subset" not in nx_graph.nodes[v]:
88
- nx_graph.nodes[v]["subset"] = partition_count
89
-
90
- auto_layout = automatic_layouts["partite"](
91
- nx_graph, scale=layout_scale, **layout_config
92
- )
93
- return {k: np.append(v, [0]) for k, v in auto_layout.items()}
94
- elif layout == "random":
95
- # the random layout places coordinates in [0, 1)
96
- # we need to rescale manually afterwards...
97
- auto_layout = automatic_layouts["random"](nx_graph, **layout_config)
98
- for k, v in auto_layout.items():
99
- auto_layout[k] = 2 * layout_scale * (v - np.array([0.5, 0.5]))
100
- return {k: np.append(v, [0]) for k, v in auto_layout.items()}
101
- else:
43
+ Examples
44
+ --------
45
+
46
+ Here is an example that arranges nodes in an n x m grid in sorted order.
47
+
48
+ .. manim:: CustomLayoutExample
49
+ :save_last_frame:
50
+
51
+ class CustomLayoutExample(Scene):
52
+ def construct(self):
53
+ import numpy as np
54
+ import networkx as nx
55
+
56
+ # create custom layout
57
+ def custom_layout(
58
+ graph: nx.Graph,
59
+ scale: float | tuple[float, float, float] = 2,
60
+ n: int | None = None,
61
+ *args: Any,
62
+ **kwargs: Any,
63
+ ):
64
+ nodes = sorted(list(graph))
65
+ height = len(nodes) // n
66
+ return {
67
+ node: (scale * np.array([
68
+ (i % n) - (n-1)/2,
69
+ -(i // n) + height/2,
70
+ 0
71
+ ])) for i, node in enumerate(graph)
72
+ }
73
+
74
+ # draw graph
75
+ n = 4
76
+ graph = Graph(
77
+ [i for i in range(4 * 2 - 1)],
78
+ [(0, 1), (0, 4), (1, 2), (1, 5), (2, 3), (2, 6), (4, 5), (5, 6)],
79
+ labels=True,
80
+ layout=custom_layout,
81
+ layout_config={'n': n}
82
+ )
83
+ self.add(graph)
84
+
85
+ Several automatic layouts are provided by manim, and can be used by passing their name as the ``layout`` parameter to :meth:`~.Graph.change_layout`.
86
+ Alternatively, a custom layout function can be passed to :meth:`~.Graph.change_layout` as the ``layout`` parameter. Such a function must adhere to the :class:`~.LayoutFunction` protocol.
87
+
88
+ The :class:`~.LayoutFunction` s provided by manim are illustrated below:
89
+
90
+ - Circular Layout: places the vertices on a circle
91
+
92
+ .. manim:: CircularLayout
93
+ :save_last_frame:
94
+
95
+ class CircularLayout(Scene):
96
+ def construct(self):
97
+ graph = Graph(
98
+ [1, 2, 3, 4, 5, 6],
99
+ [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 1), (5, 1), (1, 3), (3, 5)],
100
+ layout="circular",
101
+ labels=True
102
+ )
103
+ self.add(graph)
104
+
105
+ - Kamada Kawai Layout: tries to place the vertices such that the given distances between them are respected
106
+
107
+ .. manim:: KamadaKawaiLayout
108
+ :save_last_frame:
109
+
110
+ class KamadaKawaiLayout(Scene):
111
+ def construct(self):
112
+ from collections import defaultdict
113
+ distances: dict[int, dict[int, float]] = defaultdict(dict)
114
+
115
+ # set desired distances
116
+ distances[1][2] = 1 # distance between vertices 1 and 2 is 1
117
+ distances[2][3] = 1 # distance between vertices 2 and 3 is 1
118
+ distances[3][4] = 2 # etc
119
+ distances[4][5] = 3
120
+ distances[5][6] = 5
121
+ distances[6][1] = 8
122
+
123
+ graph = Graph(
124
+ [1, 2, 3, 4, 5, 6],
125
+ [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 1)],
126
+ layout="kamada_kawai",
127
+ layout_config={"dist": distances},
128
+ layout_scale=4,
129
+ labels=True
130
+ )
131
+ self.add(graph)
132
+
133
+ - Partite Layout: places vertices into distinct partitions
134
+
135
+ .. manim:: PartiteLayout
136
+ :save_last_frame:
137
+
138
+ class PartiteLayout(Scene):
139
+ def construct(self):
140
+ graph = Graph(
141
+ [1, 2, 3, 4, 5, 6],
142
+ [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 1), (5, 1), (1, 3), (3, 5)],
143
+ layout="partite",
144
+ layout_config={"partitions": [[1,2],[3,4],[5,6]]},
145
+ labels=True
146
+ )
147
+ self.add(graph)
148
+
149
+ - Planar Layout: places vertices such that edges do not cross
150
+
151
+ .. manim:: PlanarLayout
152
+ :save_last_frame:
153
+
154
+ class PlanarLayout(Scene):
155
+ def construct(self):
156
+ graph = Graph(
157
+ [1, 2, 3, 4, 5, 6],
158
+ [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 1), (5, 1), (1, 3), (3, 5)],
159
+ layout="planar",
160
+ layout_scale=4,
161
+ labels=True
162
+ )
163
+ self.add(graph)
164
+
165
+ - Random Layout: randomly places vertices
166
+
167
+ .. manim:: RandomLayout
168
+ :save_last_frame:
169
+
170
+ class RandomLayout(Scene):
171
+ def construct(self):
172
+ graph = Graph(
173
+ [1, 2, 3, 4, 5, 6],
174
+ [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 1), (5, 1), (1, 3), (3, 5)],
175
+ layout="random",
176
+ labels=True
177
+ )
178
+ self.add(graph)
179
+
180
+ - Shell Layout: places vertices in concentric circles
181
+
182
+ .. manim:: ShellLayout
183
+ :save_last_frame:
184
+
185
+ class ShellLayout(Scene):
186
+ def construct(self):
187
+ nlist = [[1, 2, 3], [4, 5, 6, 7, 8, 9]]
188
+ graph = Graph(
189
+ [1, 2, 3, 4, 5, 6, 7, 8, 9],
190
+ [(1, 2), (2, 3), (3, 1), (4, 1), (4, 2), (5, 2), (6, 2), (6, 3), (7, 3), (8, 3), (8, 1), (9, 1)],
191
+ layout="shell",
192
+ layout_config={"nlist": nlist},
193
+ labels=True
194
+ )
195
+ self.add(graph)
196
+
197
+ - Spectral Layout: places vertices using the eigenvectors of the graph Laplacian (clusters nodes which are an approximation of the ratio cut)
198
+
199
+ .. manim:: SpectralLayout
200
+ :save_last_frame:
201
+
202
+ class SpectralLayout(Scene):
203
+ def construct(self):
204
+ graph = Graph(
205
+ [1, 2, 3, 4, 5, 6],
206
+ [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 1), (5, 1), (1, 3), (3, 5)],
207
+ layout="spectral",
208
+ labels=True
209
+ )
210
+ self.add(graph)
211
+
212
+ - Sprial Layout: places vertices in a spiraling pattern
213
+
214
+ .. manim:: SpiralLayout
215
+ :save_last_frame:
216
+
217
+ class SpiralLayout(Scene):
218
+ def construct(self):
219
+ graph = Graph(
220
+ [1, 2, 3, 4, 5, 6],
221
+ [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 1), (5, 1), (1, 3), (3, 5)],
222
+ layout="spiral",
223
+ labels=True
224
+ )
225
+ self.add(graph)
226
+
227
+ - Spring Layout: places nodes according to the Fruchterman-Reingold force-directed algorithm (attempts to minimize edge length while maximizing node separation)
228
+
229
+ .. manim:: SpringLayout
230
+ :save_last_frame:
231
+
232
+ class SpringLayout(Scene):
233
+ def construct(self):
234
+ graph = Graph(
235
+ [1, 2, 3, 4, 5, 6],
236
+ [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 1), (5, 1), (1, 3), (3, 5)],
237
+ layout="spring",
238
+ labels=True
239
+ )
240
+ self.add(graph)
241
+
242
+ - Tree Layout: places vertices into a tree with a root node and branches (can only be used with legal trees)
243
+
244
+ .. manim:: TreeLayout
245
+ :save_last_frame:
246
+
247
+ class TreeLayout(Scene):
248
+ def construct(self):
249
+ graph = Graph(
250
+ [1, 2, 3, 4, 5, 6, 7],
251
+ [(1, 2), (1, 3), (2, 4), (2, 5), (3, 6), (3, 7)],
252
+ layout="tree",
253
+ layout_config={"root_vertex": 1},
254
+ labels=True
255
+ )
256
+ self.add(graph)
257
+
258
+ """
259
+
260
+ def __call__(
261
+ self,
262
+ graph: NxGraph,
263
+ scale: float | tuple[float, float, float] = 2,
264
+ *args: Any,
265
+ **kwargs: Any,
266
+ ) -> dict[Hashable, Point3D]:
267
+ """Given a graph and a scale, return a dictionary of coordinates.
268
+
269
+ Parameters
270
+ ----------
271
+ graph
272
+ The underlying NetworkX graph to be laid out. DO NOT MODIFY.
273
+ scale
274
+ Either a single float value, or a tuple of three float values specifying the scale along each axis.
275
+
276
+ Returns
277
+ -------
278
+ dict[Hashable, Point3D]
279
+ A dictionary mapping vertices to their positions.
280
+ """
281
+ ...
282
+
283
+
284
+ def _partite_layout(
285
+ nx_graph: NxGraph,
286
+ scale: float = 2,
287
+ partitions: Sequence[Sequence[Hashable]] | None = None,
288
+ **kwargs: Any,
289
+ ) -> dict[Hashable, Point3D]:
290
+ if partitions is None or len(partitions) == 0:
102
291
  raise ValueError(
103
- f"The layout '{layout}' is neither a recognized automatic layout, "
104
- "nor a vertex placement dictionary.",
292
+ "The partite layout requires partitions parameter to contain the partition of the vertices",
105
293
  )
294
+ partition_count = len(partitions)
295
+ for i in range(partition_count):
296
+ for v in partitions[i]:
297
+ if nx_graph.nodes[v] is None:
298
+ raise ValueError(
299
+ "The partition must contain arrays of vertices in the graph",
300
+ )
301
+ nx_graph.nodes[v]["subset"] = i
302
+ # Add missing vertices to their own side
303
+ for v in nx_graph.nodes:
304
+ if "subset" not in nx_graph.nodes[v]:
305
+ nx_graph.nodes[v]["subset"] = partition_count
306
+
307
+ return nx.layout.multipartite_layout(nx_graph, scale=scale, **kwargs)
308
+
309
+
310
+ def _random_layout(nx_graph: NxGraph, scale: float = 2, **kwargs: Any):
311
+ # the random layout places coordinates in [0, 1)
312
+ # we need to rescale manually afterwards...
313
+ auto_layout = nx.layout.random_layout(nx_graph, **kwargs)
314
+ for k, v in auto_layout.items():
315
+ auto_layout[k] = 2 * scale * (v - np.array([0.5, 0.5]))
316
+ return {k: np.append(v, [0]) for k, v in auto_layout.items()}
106
317
 
107
318
 
108
319
  def _tree_layout(
109
- T: nx.classes.graph.Graph | nx.classes.digraph.DiGraph,
110
- root_vertex: Hashable | None,
320
+ T: NxGraph,
321
+ root_vertex: Hashable | None = None,
111
322
  scale: float | tuple | None = 2,
112
323
  vertex_spacing: tuple | None = None,
113
324
  orientation: str = "down",
@@ -127,10 +338,7 @@ def _tree_layout(
127
338
  parent = {u: root_vertex for u in children[root_vertex]}
128
339
  pos = {}
129
340
  obstruction = [0.0] * len(T)
130
- if orientation == "down":
131
- o = -1
132
- else:
133
- o = 1
341
+ o = -1 if orientation == "down" else 1
134
342
 
135
343
  def slide(v, dx):
136
344
  """
@@ -193,15 +401,9 @@ def _tree_layout(
193
401
  if isinstance(scale, (float, int)) and (width > 0 or height > 0):
194
402
  sf = 2 * scale / max(width, height)
195
403
  elif isinstance(scale, tuple):
196
- if scale[0] is not None and width > 0:
197
- sw = 2 * scale[0] / width
198
- else:
199
- sw = 1
404
+ sw = 2 * scale[0] / width if scale[0] is not None and width > 0 else 1
200
405
 
201
- if scale[1] is not None and height > 0:
202
- sh = 2 * scale[1] / height
203
- else:
204
- sh = 1
406
+ sh = 2 * scale[1] / height if scale[1] is not None and height > 0 else 1
205
407
 
206
408
  sf = np.array([sw, sh, 0])
207
409
  else:
@@ -212,6 +414,68 @@ def _tree_layout(
212
414
  return {v: (np.array([x, y, 0]) - center) * sf for v, (x, y) in pos.items()}
213
415
 
214
416
 
417
+ LayoutName = Literal[
418
+ "circular",
419
+ "kamada_kawai",
420
+ "partite",
421
+ "planar",
422
+ "random",
423
+ "shell",
424
+ "spectral",
425
+ "spiral",
426
+ "spring",
427
+ "tree",
428
+ ]
429
+
430
+ _layouts: dict[LayoutName, LayoutFunction] = {
431
+ "circular": cast(LayoutFunction, nx.layout.circular_layout),
432
+ "kamada_kawai": cast(LayoutFunction, nx.layout.kamada_kawai_layout),
433
+ "partite": cast(LayoutFunction, _partite_layout),
434
+ "planar": cast(LayoutFunction, nx.layout.planar_layout),
435
+ "random": cast(LayoutFunction, _random_layout),
436
+ "shell": cast(LayoutFunction, nx.layout.shell_layout),
437
+ "spectral": cast(LayoutFunction, nx.layout.spectral_layout),
438
+ "spiral": cast(LayoutFunction, nx.layout.spiral_layout),
439
+ "spring": cast(LayoutFunction, nx.layout.spring_layout),
440
+ "tree": cast(LayoutFunction, _tree_layout),
441
+ }
442
+
443
+
444
+ def _determine_graph_layout(
445
+ nx_graph: nx.classes.graph.Graph | nx.classes.digraph.DiGraph,
446
+ layout: LayoutName | dict[Hashable, Point3DLike] | LayoutFunction = "spring",
447
+ layout_scale: float | tuple[float, float, float] = 2,
448
+ layout_config: dict[str, Any] | None = None,
449
+ ) -> dict[Hashable, Point3DLike]:
450
+ if layout_config is None:
451
+ layout_config = {}
452
+
453
+ if isinstance(layout, dict):
454
+ return layout
455
+ elif layout in _layouts:
456
+ auto_layout = _layouts[layout](nx_graph, scale=layout_scale, **layout_config)
457
+ # NetworkX returns a dictionary of 3D points if the dimension
458
+ # is specified to be 3. Otherwise, it returns a dictionary of
459
+ # 2D points, so adjusting is required.
460
+ if (
461
+ layout_config.get("dim") == 3
462
+ or auto_layout[next(auto_layout.__iter__())].shape[0] == 3
463
+ ):
464
+ return auto_layout
465
+ else:
466
+ return {k: np.append(v, [0]) for k, v in auto_layout.items()}
467
+ else:
468
+ try:
469
+ return cast(LayoutFunction, layout)(
470
+ nx_graph, scale=layout_scale, **layout_config
471
+ )
472
+ except TypeError as e:
473
+ raise ValueError(
474
+ f"The layout '{layout}' is neither a recognized layout, a layout function,"
475
+ "nor a vertex placement dictionary.",
476
+ ) from e
477
+
478
+
215
479
  class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
216
480
  """Abstract base class for graphs (that is, a collection of vertices
217
481
  connected with edges).
@@ -254,14 +518,14 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
254
518
  layout
255
519
  Either one of ``"spring"`` (the default), ``"circular"``, ``"kamada_kawai"``,
256
520
  ``"planar"``, ``"random"``, ``"shell"``, ``"spectral"``, ``"spiral"``, ``"tree"``, and ``"partite"``
257
- for automatic vertex positioning using ``networkx``
521
+ for automatic vertex positioning primarily using ``networkx``
258
522
  (see `their documentation <https://networkx.org/documentation/stable/reference/drawing.html#module-networkx.drawing.layout>`_
259
- for more details), or a dictionary specifying a coordinate (value)
260
- for each vertex (key) for manual positioning.
523
+ for more details), a dictionary specifying a coordinate (value)
524
+ for each vertex (key) for manual positioning, or a .:class:`~.LayoutFunction` with a user-defined automatic layout.
261
525
  layout_config
262
- Only for automatically generated layouts. A dictionary whose entries
263
- are passed as keyword arguments to the automatic layout algorithm
264
- specified via ``layout`` of``networkx``.
526
+ Only for automatic layouts. A dictionary whose entries
527
+ are passed as keyword arguments to the named layout or automatic layout function
528
+ specified via ``layout``.
265
529
  The ``tree`` layout also accepts a special parameter ``vertex_spacing``
266
530
  passed as a keyword argument inside the ``layout_config`` dictionary.
267
531
  Passing a tuple ``(space_x, space_y)`` as this argument overrides
@@ -288,6 +552,7 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
288
552
  all other configuration options for a vertex.
289
553
  edge_type
290
554
  The mobject class used for displaying edges in the scene.
555
+ Must be a subclass of :class:`~.Line` for default updaters to work.
291
556
  edge_config
292
557
  Either a dictionary containing keyword arguments to be passed
293
558
  to the class specified via ``edge_type``, or a dictionary whose
@@ -297,18 +562,18 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
297
562
 
298
563
  def __init__(
299
564
  self,
300
- vertices: list[Hashable],
301
- edges: list[tuple[Hashable, Hashable]],
565
+ vertices: Sequence[Hashable],
566
+ edges: Sequence[tuple[Hashable, Hashable]],
302
567
  labels: bool | dict = False,
303
568
  label_fill_color: str = BLACK,
304
- layout: str | dict = "spring",
305
- layout_scale: float | tuple = 2,
569
+ layout: LayoutName | dict[Hashable, Point3DLike] | LayoutFunction = "spring",
570
+ layout_scale: float | tuple[float, float, float] = 2,
306
571
  layout_config: dict | None = None,
307
572
  vertex_type: type[Mobject] = Dot,
308
573
  vertex_config: dict | None = None,
309
574
  vertex_mobjects: dict | None = None,
310
575
  edge_type: type[Mobject] = Line,
311
- partitions: list[list[Hashable]] | None = None,
576
+ partitions: Sequence[Sequence[Hashable]] | None = None,
312
577
  root_vertex: Hashable | None = None,
313
578
  edge_config: dict | None = None,
314
579
  ) -> None:
@@ -319,15 +584,6 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
319
584
  nx_graph.add_edges_from(edges)
320
585
  self._graph = nx_graph
321
586
 
322
- self._layout = _determine_graph_layout(
323
- nx_graph,
324
- layout=layout,
325
- layout_scale=layout_scale,
326
- layout_config=layout_config,
327
- partitions=partitions,
328
- root_vertex=root_vertex,
329
- )
330
-
331
587
  if isinstance(labels, dict):
332
588
  self._labels = labels
333
589
  elif isinstance(labels, bool):
@@ -361,8 +617,14 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
361
617
 
362
618
  self.vertices = {v: vertex_type(**self._vertex_config[v]) for v in vertices}
363
619
  self.vertices.update(vertex_mobjects)
364
- for v in self.vertices:
365
- self[v].move_to(self._layout[v])
620
+
621
+ self.change_layout(
622
+ layout=layout,
623
+ layout_scale=layout_scale,
624
+ layout_config=layout_config,
625
+ partitions=partitions,
626
+ root_vertex=root_vertex,
627
+ )
366
628
 
367
629
  # build edge_config
368
630
  if edge_config is None:
@@ -399,7 +661,7 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
399
661
  self.add_updater(self.update_edges)
400
662
 
401
663
  @staticmethod
402
- def _empty_networkx_graph():
664
+ def _empty_networkx_graph() -> nx.classes.graph.Graph:
403
665
  """Return an empty networkx graph for the given graph type."""
404
666
  raise NotImplementedError("To be implemented in concrete subclasses")
405
667
 
@@ -415,15 +677,16 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
415
677
  def _create_vertex(
416
678
  self,
417
679
  vertex: Hashable,
418
- position: np.ndarray | None = None,
680
+ position: Point3DLike | None = None,
419
681
  label: bool = False,
420
682
  label_fill_color: str = BLACK,
421
683
  vertex_type: type[Mobject] = Dot,
422
684
  vertex_config: dict | None = None,
423
685
  vertex_mobject: dict | None = None,
424
- ) -> tuple[Hashable, np.ndarray, dict, Mobject]:
425
- if position is None:
426
- position = self.get_center()
686
+ ) -> tuple[Hashable, Point3D, dict, Mobject]:
687
+ np_position: Point3D = (
688
+ self.get_center() if position is None else np.asarray(position)
689
+ )
427
690
 
428
691
  if vertex_config is None:
429
692
  vertex_config = {}
@@ -452,14 +715,14 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
452
715
  if vertex_mobject is None:
453
716
  vertex_mobject = vertex_type(**vertex_config)
454
717
 
455
- vertex_mobject.move_to(position)
718
+ vertex_mobject.move_to(np_position)
456
719
 
457
- return (vertex, position, vertex_config, vertex_mobject)
720
+ return (vertex, np_position, vertex_config, vertex_mobject)
458
721
 
459
722
  def _add_created_vertex(
460
723
  self,
461
724
  vertex: Hashable,
462
- position: np.ndarray,
725
+ position: Point3DLike,
463
726
  vertex_config: dict,
464
727
  vertex_mobject: Mobject,
465
728
  ) -> Mobject:
@@ -485,7 +748,7 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
485
748
  def _add_vertex(
486
749
  self,
487
750
  vertex: Hashable,
488
- position: np.ndarray | None = None,
751
+ position: Point3DLike | None = None,
489
752
  label: bool = False,
490
753
  label_fill_color: str = BLACK,
491
754
  vertex_type: type[Mobject] = Dot,
@@ -540,7 +803,7 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
540
803
  vertex_type: type[Mobject] = Dot,
541
804
  vertex_config: dict | None = None,
542
805
  vertex_mobjects: dict | None = None,
543
- ) -> Iterable[tuple[Hashable, np.ndarray, dict, Mobject]]:
806
+ ) -> Iterable[tuple[Hashable, Point3D, dict, Mobject]]:
544
807
  if positions is None:
545
808
  positions = {}
546
809
  if vertex_mobjects is None:
@@ -555,7 +818,7 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
555
818
  labels = {v: labels for v in vertices}
556
819
  else:
557
820
  assert isinstance(labels, dict)
558
- base_labels = {v: False for v in vertices}
821
+ base_labels = dict.fromkeys(vertices, False)
559
822
  base_labels.update(labels)
560
823
  labels = base_labels
561
824
 
@@ -580,7 +843,7 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
580
843
  label_fill_color=label_fill_color,
581
844
  vertex_type=vertex_type,
582
845
  vertex_config=vertex_config[v],
583
- vertex_mobject=vertex_mobjects[v] if v in vertex_mobjects else None,
846
+ vertex_mobject=vertex_mobjects.get(v),
584
847
  )
585
848
  for v in vertices
586
849
  ]
@@ -944,9 +1207,9 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
944
1207
 
945
1208
  def change_layout(
946
1209
  self,
947
- layout: str | dict = "spring",
948
- layout_scale: float = 2,
949
- layout_config: dict | None = None,
1210
+ layout: LayoutName | dict[Hashable, Point3DLike] | LayoutFunction = "spring",
1211
+ layout_scale: float | tuple[float, float, float] = 2,
1212
+ layout_config: dict[str, Any] | None = None,
950
1213
  partitions: list[list[Hashable]] | None = None,
951
1214
  root_vertex: Hashable | None = None,
952
1215
  ) -> Graph:
@@ -970,14 +1233,19 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
970
1233
  self.play(G.animate.change_layout("circular"))
971
1234
  self.wait()
972
1235
  """
1236
+ layout_config = {} if layout_config is None else layout_config
1237
+ if partitions is not None and "partitions" not in layout_config:
1238
+ layout_config["partitions"] = partitions
1239
+ if root_vertex is not None and "root_vertex" not in layout_config:
1240
+ layout_config["root_vertex"] = root_vertex
1241
+
973
1242
  self._layout = _determine_graph_layout(
974
1243
  self._graph,
975
1244
  layout=layout,
976
1245
  layout_scale=layout_scale,
977
1246
  layout_config=layout_config,
978
- partitions=partitions,
979
- root_vertex=root_vertex,
980
1247
  )
1248
+
981
1249
  for v in self.vertices:
982
1250
  self[v].move_to(self._layout[v])
983
1251
  return self
@@ -1233,13 +1501,16 @@ class Graph(GenericGraph):
1233
1501
  VERTEX_CONF = {"radius": 0.25, "color": BLUE_B, "fill_opacity": 1}
1234
1502
 
1235
1503
  def expand_vertex(self, g, vertex_id: str, depth: int):
1236
- new_vertices = [f"{vertex_id}/{i}" for i in range(self.CHILDREN_PER_VERTEX)]
1504
+ new_vertices = [
1505
+ f"{vertex_id}/{i}" for i in range(self.CHILDREN_PER_VERTEX)
1506
+ ]
1237
1507
  new_edges = [(vertex_id, child_id) for child_id in new_vertices]
1238
1508
  g.add_edges(
1239
1509
  *new_edges,
1240
1510
  vertex_config=self.VERTEX_CONF,
1241
1511
  positions={
1242
- k: g.vertices[vertex_id].get_center() + 0.1 * DOWN for k in new_vertices
1512
+ k: g.vertices[vertex_id].get_center() + 0.1 * DOWN
1513
+ for k in new_vertices
1243
1514
  },
1244
1515
  )
1245
1516
  if depth < self.DEPTH:
@@ -1283,7 +1554,12 @@ class Graph(GenericGraph):
1283
1554
  def update_edges(self, graph):
1284
1555
  for (u, v), edge in graph.edges.items():
1285
1556
  # Undirected graph has a Line edge
1286
- edge.put_start_and_end_on(graph[u].get_center(), graph[v].get_center())
1557
+ edge.set_points_by_ends(
1558
+ graph[u].get_center(),
1559
+ graph[v].get_center(),
1560
+ buff=self._edge_config.get("buff", 0),
1561
+ path_arc=self._edge_config.get("path_arc", 0),
1562
+ )
1287
1563
 
1288
1564
  def __repr__(self: Graph) -> str:
1289
1565
  return f"Undirected graph on {len(self.vertices)} vertices and {len(self.edges)} edges"
@@ -1492,10 +1768,15 @@ class DiGraph(GenericGraph):
1492
1768
  deformed.
1493
1769
  """
1494
1770
  for (u, v), edge in graph.edges.items():
1495
- edge_type = type(edge)
1496
1771
  tip = edge.pop_tips()[0]
1497
- new_edge = edge_type(self[u], self[v], **self._edge_config[(u, v)])
1498
- edge.become(new_edge)
1772
+ # Passing the Mobject instead of the vertex makes the tip
1773
+ # stop on the bounding box of the vertex.
1774
+ edge.set_points_by_ends(
1775
+ graph[u],
1776
+ graph[v],
1777
+ buff=self._edge_config.get("buff", 0),
1778
+ path_arc=self._edge_config.get("path_arc", 0),
1779
+ )
1499
1780
  edge.add_tip(tip)
1500
1781
 
1501
1782
  def __repr__(self: DiGraph) -> str: