InterpolatePy 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,613 @@
1
+ """
2
+ Module for generating and managing trapezoidal velocity profiles for trajectory planning.
3
+ """
4
+
5
+ from collections.abc import Callable
6
+ from dataclasses import dataclass
7
+
8
+ import numpy as np
9
+
10
+
11
# Numeric constants shared by the trajectory routines
MIN_POINTS = 2  # Fewest waypoints accepted for multi-point interpolation
EPSILON = 1e-10  # Guard value: floors denominators and absorbs round-off before sqrt
14
+
15
+
16
+ @dataclass
17
+ class TrajectoryParams:
18
+ """Parameters for trapezoidal trajectory generation."""
19
+
20
+ q0: float
21
+ q1: float
22
+ t0: float = 0.0
23
+ v0: float = 0.0
24
+ v1: float = 0.0
25
+ amax: float | None = None
26
+ vmax: float | None = None
27
+ duration: float | None = None
28
+
29
+
30
@dataclass
class CalculationParams:
    """
    Values handed to the low-level profile calculations.

    Attributes
    ----------
    q0, q1 : float
        Start and end positions; the calculations use the displacement ``q1 - q0``.
    v0, v1 : float
        Start and end velocities.
    amax : float
        Acceleration limit.
    """

    # Every field is required; there are no defaults
    q0: float
    q1: float
    v0: float
    v1: float
    amax: float
39
+
40
+
41
+ @dataclass
42
+ class InterpolationParams:
43
+ """Parameters for multi-point interpolation."""
44
+
45
+ points: list[float]
46
+ v0: float = 0.0
47
+ vn: float = 0.0
48
+ inter_velocities: list[float] | None = None
49
+ times: list[float] | None = None
50
+ amax: float = 10.0
51
+ vmax: float | None = None
52
+
53
+
54
class TrapezoidalTrajectory:
    """
    Generate trapezoidal velocity profiles for trajectory planning.

    The class is a namespace of static and class methods. It builds
    single-segment profiles (``generate_trajectory``) and chains them through
    a list of waypoints (``interpolate_waypoints``).
    """

    @staticmethod
    def _non_negative_or_raise(value: float, message: str) -> float:
        """
        Prepare a value for ``sqrt``.

        Non-negative values are returned unchanged and negatives smaller in
        magnitude than ``EPSILON`` are treated as round-off and become 0.0.
        Anything else (including NaN) raises ``ValueError(message)``.
        """
        if value >= 0:
            return value
        if value > -EPSILON:
            return 0.0
        raise ValueError(message)

    @staticmethod
    def _resolve_max_velocity(
        h_values: list[float], v_max: float | None, amax: float | None
    ) -> float:
        """
        Return a strictly positive cruise-velocity magnitude for a waypoint path.

        A caller-supplied ``v_max`` is only validated. When it is ``None`` the
        value is estimated from ``amax`` and the segment displacements as the
        minimum of three heuristics (time based, segment based, curvature based).

        Parameters
        ----------
        h_values : list[float]
            Signed displacements between consecutive waypoints
        v_max : float | None
            Requested velocity limit, or None to estimate it
        amax : float | None
            Acceleration limit (its absolute value is used)

        Returns
        -------
        float
            Velocity magnitude, guaranteed to be greater than zero

        Raises
        ------
        ValueError
            If neither ``v_max`` nor ``amax`` is given, or no positive velocity
            can be determined (non-positive ``v_max``, zero ``amax``, or a path
            without any displacement)
        """
        if v_max is None:
            if amax is None:
                raise ValueError("Either v_max or amax must be provided")

            amax = abs(amax)
            lengths = [abs(h) for h in h_values]
            total_distance = sum(lengths)

            # Without displacement or acceleration no velocity can be derived;
            # bail out before the divisions below.
            if total_distance <= 0 or amax <= 0:
                raise ValueError("Failed to calculate a valid maximum velocity")

            # OPTION 1: time-based. Estimate a total duration from the
            # acceleration equation and take 75% of the resulting average speed.
            estimated_duration = np.sqrt(2 * total_distance / amax)
            v_time = total_distance / estimated_duration * 0.75

            # OPTION 2: segment-based. 80% of the smallest per-segment velocity
            # sqrt(amax * length / 2). Zero-length segments (repeated waypoints)
            # are skipped, otherwise they would force the estimate to zero.
            v_segments = 0.8 * min(
                np.sqrt(amax * length / 2) for length in lengths if length > 0
            )

            # OPTION 3: curvature-based. Reversals count as a full change,
            # otherwise the relative change between neighbouring displacements
            # is used (0 when both neighbours are zero-length).
            direction_changes = []
            for h_prev, h_next in zip(h_values, h_values[1:]):
                if h_prev * h_next < 0:
                    direction_changes.append(1.0)
                else:
                    scale = abs(h_prev) + abs(h_next)
                    direction_changes.append(abs(h_next - h_prev) / scale if scale > 0 else 0.0)
            avg_change = sum(direction_changes) / max(len(direction_changes), 1)
            v_curvature = np.sqrt(amax * total_distance / (len(h_values) + 5 * avg_change))

            # Choose the minimum of all approaches for safety
            v_max = float(min(v_time, v_segments, v_curvature))

        # ``not v_max > 0`` also rejects NaN
        if not v_max > 0:
            raise ValueError("Failed to calculate a valid maximum velocity")

        return v_max

    @staticmethod
    def _calculate_duration_based_trajectory(
        params: CalculationParams, duration: float
    ) -> tuple[float, float, float]:
        """
        Calculate trajectory parameters for duration-based constraints.

        If ``params.amax`` is below the minimum acceleration needed to finish
        within ``duration``, that minimum is used instead and a
        ``RuntimeWarning`` is issued.

        Parameters
        ----------
        params : CalculationParams
            Basic trajectory parameters
        duration : float
            Desired duration (must not be negative)

        Returns
        -------
        tuple[float, float, float]
            Cruise velocity, acceleration time, deceleration time

        Raises
        ------
        ValueError
            If the duration is negative or the trajectory is not feasible with
            the given parameters
        """
        import warnings  # local import: only this method needs it

        v0 = params.v0
        v1 = params.v1
        amax = params.amax
        h = params.q1 - params.q0

        if duration < 0:
            raise ValueError("Trajectory duration must be non-negative.")

        # Check feasibility using equation (3.14)
        if amax * h < abs(v0**2 - v1**2) / 2:
            raise ValueError("Trajectory not feasible. Try increasing amax or reducing velocities.")

        # Minimum required acceleration (equation 3.15)
        term_under_sqrt = TrapezoidalTrajectory._non_negative_or_raise(
            4 * h**2 - 4 * h * (v0 + v1) * duration + 2 * (v0**2 + v1**2) * duration**2,
            "Trajectory not feasible with given duration. Try increasing duration.",
        )
        alim = (2 * h - duration * (v0 + v1) + np.sqrt(term_under_sqrt)) / max(duration**2, EPSILON)

        if amax < alim:
            # The requested limit cannot meet the duration: use the minimum that can
            amax = alim
            warnings.warn(
                f"Using minimum required acceleration: {alim:.4f}",
                RuntimeWarning,
                stacklevel=2,
            )

        # Constant velocity (vv) from the equation in section 3.2.7
        sqrt_term = TrapezoidalTrajectory._non_negative_or_raise(
            amax**2 * duration**2 - 4 * amax * h + 2 * amax * (v0 + v1) * duration - (v0 - v1) ** 2,
            "Numerical issue in trajectory calculation. "
            "The parameters may lead to an invalid trajectory.",
        )
        vv = 0.5 * (v0 + v1 + amax * duration - np.sqrt(sqrt_term))

        # EPSILON keeps the divisions defined when amax is zero
        ta = (vv - v0) / (amax + EPSILON)
        td = (vv - v1) / (amax + EPSILON)

        return vv, ta, td

    @staticmethod
    def _calculate_velocity_based_trajectory(
        params: CalculationParams, vmax: float
    ) -> tuple[float, float, float, float]:
        """
        Calculate trajectory parameters for velocity-based constraints.

        Parameters
        ----------
        params : CalculationParams
            Basic trajectory parameters
        vmax : float
            Maximum velocity

        Returns
        -------
        tuple[float, float, float, float]
            Cruise velocity, acceleration time, deceleration time, total duration

        Raises
        ------
        ValueError
            If the triangular profile cannot reach the boundary velocities
            within the displacement (negative acceleration or deceleration time)
        """
        v0 = params.v0
        v1 = params.v1
        amax = params.amax
        h = params.q1 - params.q0

        # Determine if vmax is reached (Case 1 or Case 2)
        if h * amax > vmax**2 - (v0**2 + v1**2) / 2:
            # vmax is reached: trapezoidal profile with a cruise phase
            vv = vmax
            ta = (vmax - v0) / (amax + EPSILON)
            td = (vmax - v1) / (amax + EPSILON)

            safe_vmax = max(vmax, EPSILON)
            # Ratios are clipped just inside [-1, 1] to avoid numerical issues
            v0_vmax_ratio = np.clip(v0 / safe_vmax, -1.0 + EPSILON, 1.0 - EPSILON)
            v1_vmax_ratio = np.clip(v1 / safe_vmax, -1.0 + EPSILON, 1.0 - EPSILON)

            duration = (
                (h / safe_vmax)
                + (vmax / (2 * amax + EPSILON)) * (1 - v0_vmax_ratio) ** 2
                + (vmax / (2 * amax + EPSILON)) * (1 - v1_vmax_ratio) ** 2
            )
        else:
            # vmax is not reached (triangular profile)
            sqrt_term = TrapezoidalTrajectory._non_negative_or_raise(
                h * amax + (v0**2 + v1**2) / 2,
                "Invalid trajectory parameters. The calculation resulted in "
                "a negative value under a square root.",
            )
            vv = np.sqrt(sqrt_term)

            ta = (vv - v0) / (amax + EPSILON)
            td = (vv - v1) / (amax + EPSILON)

            # A negative phase time means the peak velocity lies below a boundary
            # velocity: that boundary speed cannot be reached (or shed) within h.
            if ta < -EPSILON or td < -EPSILON:
                raise ValueError(
                    "Trajectory not feasible. Try increasing amax or reducing velocities."
                )

            duration = ta + td

        return vv, ta, td, duration

    @staticmethod
    def generate_trajectory(
        params: TrajectoryParams,
    ) -> tuple[Callable[[float], tuple[float, float, float]], float]:
        """
        Generate a trapezoidal trajectory with non-null initial and final velocities.

        Handles both positive and negative displacements according to section 3.4.2.
        Uses absolute values for amax and vmax. Exactly one of ``params.duration``
        and ``params.vmax`` must be given together with ``params.amax``.

        Parameters
        ----------
        params : TrajectoryParams
            Parameters for trajectory generation including initial and final positions,
            velocities, acceleration and velocity limits, and optional duration.

        Returns
        -------
        tuple[Callable[[float], tuple[float, float, float]], float]
            A tuple containing:
            - Function that computes position, velocity, and acceleration at time t
              (t is clamped to the segment's time interval)
            - Duration of trajectory

        Raises
        ------
        ValueError
            If parameter combination is invalid or trajectory is not feasible
        """
        q0, q1, t0 = params.q0, params.q1, params.t0
        v0, v1 = params.v0, params.v1
        amax, vmax, t_duration = params.amax, params.vmax, params.duration

        # Parameter validation
        if amax is None:
            raise ValueError("Maximum acceleration (amax) must be provided")
        if t_duration is None and vmax is None:
            raise ValueError("Either duration or maximum velocity (vmax) must be provided")
        if t_duration is not None and vmax is not None:
            raise ValueError(
                "Invalid parameter combination. Provide either (amax, duration) or (amax, vmax)."
            )

        # Limits are magnitudes
        amax = abs(amax)
        if vmax is not None:
            vmax = abs(vmax)

        # Negative displacement (q1 < q0), section 3.4.2: solve the mirrored
        # problem and flip the signs of the results afterwards.
        invert_results = q1 < q0
        if invert_results:
            q0, q1 = -q0, -q1
            v0, v1 = -v0, -v1

        calc_params = CalculationParams(q0=q0, q1=q1, v0=v0, v1=v1, amax=amax)

        if t_duration is not None:
            # Case 1: preassigned duration and acceleration
            vv, ta, td = TrapezoidalTrajectory._calculate_duration_based_trajectory(
                calc_params, t_duration
            )
            duration = t_duration
        else:
            # Case 2: preassigned acceleration and velocity
            vv, ta, td, duration = TrapezoidalTrajectory._calculate_velocity_based_trajectory(
                calc_params, vmax
            )

        t1 = t0 + duration

        # Floor the phase lengths so the divisions below stay defined
        ta_safe = max(ta, EPSILON)
        td_safe = max(td, EPSILON)

        def trajectory_original(t: float) -> tuple[float, float, float]:
            """
            Evaluate position, velocity, and acceleration at time t.

            Parameters
            ----------
            t : float
                Time at which to evaluate the trajectory

            Returns
            -------
            tuple[float, float, float]
                Tuple containing position, velocity, and acceleration at time t
            """
            # Ensure t is within bounds
            t = np.clip(t, t0, t1)

            position = 0.0
            velocity = 0.0
            acceleration = 0.0

            if t0 <= t < t0 + ta_safe:
                # Acceleration phase
                dt = t - t0
                position = q0 + v0 * dt + (vv - v0) / (2 * ta_safe) * dt**2
                velocity = v0 + (vv - v0) / ta_safe * dt
                acceleration = (vv - v0) / ta_safe
            elif t0 + ta_safe <= t < t1 - td_safe:
                # Constant velocity phase
                position = q0 + v0 * ta_safe / 2 + vv * (t - t0 - ta_safe / 2)
                velocity = vv
                acceleration = 0.0
            elif t1 - td_safe <= t <= t1:
                # Deceleration phase, measured backwards from the end point
                dt = t1 - t
                position = q1 - v1 * dt - (vv - v1) / (2 * td_safe) * dt**2
                velocity = v1 + (vv - v1) / td_safe * dt
                acceleration = -(vv - v1) / td_safe

            return position, velocity, acceleration

        if not invert_results:
            return trajectory_original, duration

        def trajectory(t: float) -> tuple[float, float, float]:
            """Evaluate the mirrored profile and flip the signs back (equation 3.33)."""
            pos, vel, acc = trajectory_original(t)
            return -pos, -vel, -acc

        return trajectory, duration

    @staticmethod
    def calculate_heuristic_velocities(
        q_list: list[float],
        v0: float,
        vn: float,
        v_max: float | None = None,
        amax: float | None = None,
    ) -> list[float]:
        """
        Calculate waypoint velocities from the signs of the height differences.

        An interior waypoint gets velocity 0 when the displacement changes sign
        there, otherwise ``sign(h_k) * v_max``. When ``v_max`` is not given it
        is estimated from ``amax`` and the segment lengths.

        Parameters
        ----------
        q_list : list[float]
            List of height values [q0, q1, ..., qn]
        v0 : float
            Initial velocity (assigned)
        vn : float
            Final velocity (assigned)
        v_max : float | None, optional
            Maximum velocity value (positive magnitude)
        amax : float | None, optional
            Maximum acceleration (needed if v_max is not provided)

        Returns
        -------
        list[float]
            Calculated velocities [v0, v1, ..., vn]

        Raises
        ------
        ValueError
            If neither v_max nor amax is provided, or no positive maximum
            velocity can be determined
        """
        v0 = float(v0)
        vn = float(vn)

        # Height differences h_k = q_k - q_(k-1)
        h_values = [float(q_next - q_prev) for q_prev, q_next in zip(q_list, q_list[1:])]

        v_max = TrapezoidalTrajectory._resolve_max_velocity(h_values, v_max, amax)

        velocities = [v0]

        # Intermediate velocities (v1 to v_(n-1))
        for h_prev, h_next in zip(h_values, h_values[1:]):
            if np.sign(h_prev) != np.sign(h_next):
                velocities.append(0.0)
            else:
                velocities.append(float(np.sign(h_prev) * v_max))

        velocities.append(vn)

        return velocities

    @classmethod
    def interpolate_waypoints(
        cls, params: InterpolationParams
    ) -> tuple[Callable[[float], tuple[float, float, float]], float]:
        """
        Generate a trajectory through a sequence of points using trapezoidal velocity profiles.

        Supports both positive and negative displacements. The combined
        trajectory always starts at time 0; when ``params.times`` is given only
        the differences between consecutive instants are used.

        Parameters
        ----------
        params : InterpolationParams
            Parameters for interpolation including points, velocities,
            times, and motion constraints.

        Returns
        -------
        tuple[Callable[[float], tuple[float, float, float]], float]
            A tuple containing:
            - Function that returns position, velocity, and acceleration at any time t
              (t is clamped to [0, total duration])
            - Total duration of the trajectory

        Raises
        ------
        ValueError
            If less than two points are provided, if the number of intermediate
            velocities or time instants does not fit the points, or if a
            segment is not feasible
        """
        # Snapshot the waypoints so later mutation of params cannot affect the result
        points = list(params.points)
        n_points = len(points)
        if n_points < MIN_POINTS:
            raise ValueError("At least two points are required for interpolation")

        times = params.times
        if times is not None and len(times) < n_points:
            raise ValueError(f"Expected {n_points} time instants, got {len(times)}")

        vmax = params.vmax
        if params.inter_velocities is None:
            # Resolve the velocity limit first so it is not lost when every
            # heuristic velocity turns out to be zero (two points, zig-zag paths).
            steps = [float(q_next - q_prev) for q_prev, q_next in zip(points, points[1:])]
            cruise_limit = cls._resolve_max_velocity(steps, params.vmax, params.amax)
            velocities = cls.calculate_heuristic_velocities(
                points, params.v0, params.vn, cruise_limit, params.amax
            )
            if vmax is None:
                # Never below the assigned boundary speeds
                vmax = max(cruise_limit, max(abs(v) for v in velocities))
        elif len(params.inter_velocities) != n_points - 2:
            raise ValueError(
                f"Expected {n_points - 2} intermediate velocities, "
                f"got {len(params.inter_velocities)}"
            )
        else:
            # Use provided velocities
            velocities = [params.v0, *params.inter_velocities, params.vn]

        all_trajectories = []
        cumulative_time = 0.0
        segment_end_times = [0.0]  # Start with initial time

        # Generate individual segment trajectories
        for i, (q_start, q_end) in enumerate(zip(points, points[1:])):
            if times is None:
                # Duration follows from the velocity/acceleration constraints
                constraint = {"vmax": vmax}
            else:
                # Duration is prescribed by the time instants
                constraint = {"duration": times[i + 1] - times[i]}

            traj_params = TrajectoryParams(
                q0=q_start,
                q1=q_end,
                t0=cumulative_time,
                v0=velocities[i],
                v1=velocities[i + 1],
                amax=params.amax,
                **constraint,
            )
            traj_func, segment_duration = cls.generate_trajectory(traj_params)

            cumulative_time += segment_duration
            segment_end_times.append(cumulative_time)
            all_trajectories.append(traj_func)

        total_duration = cumulative_time
        boundaries = np.asarray(segment_end_times)
        # State reported at (and after) the very end of the motion
        final_state = (points[-1], velocities[-1], 0.0)

        def trajectory_function(t: float) -> tuple[float, float, float]:
            """
            Evaluate the trajectory at time t.

            Parameters
            ----------
            t : float
                Time at which to evaluate the trajectory

            Returns
            -------
            tuple[float, float, float]
                Tuple containing position, velocity, and acceleration at time t
            """
            # Clip time to valid range
            t = np.clip(t, 0.0, total_duration)

            # Index of the segment whose time interval contains t
            segment_idx = np.searchsorted(boundaries, t, side="right") - 1

            if segment_idx < len(all_trajectories):
                return all_trajectories[segment_idx](t)
            # At the end: final position and assigned final velocity, zero acceleration
            return final_state

        return trajectory_function, total_duration
@@ -0,0 +1,96 @@
1
+ import numpy as np
2
+
3
+
4
def solve_tridiagonal(
    lower_diagonal: np.ndarray,
    main_diagonal: np.ndarray,
    upper_diagonal: np.ndarray,
    right_hand_side: np.ndarray,
) -> np.ndarray:
    """
    Solve a tridiagonal system using the Thomas algorithm.

    This function solves the equation Ax = b where A is a tridiagonal matrix.
    The inputs are copied, so the caller's arrays are never modified.

    Parameters
    ----------
    lower_diagonal : np.ndarray
        Lower diagonal elements (first element is not used).
        Must have at least as many elements as right_hand_side.
    main_diagonal : np.ndarray
        Main diagonal elements.
        Must have at least as many elements as right_hand_side.
    upper_diagonal : np.ndarray
        Upper diagonal elements (the element at index n - 1 is not used).
        Must have at least n - 1 elements.
    right_hand_side : np.ndarray
        Right-hand side vector of the equation; its length n is the system size.

    Returns
    -------
    np.ndarray
        Solution vector x (empty when right_hand_side is empty).

    Raises
    ------
    ValueError
        If a pivot is zero, either initially or after a forward elimination
        step, or if a diagonal is too short for the system size.

    Examples
    --------
    >>> import numpy as np
    >>> a = np.array([0, 1, 2, 3])  # Lower diagonal (a[0] is not used)
    >>> b = np.array([2, 3, 4, 5])  # Main diagonal
    >>> c = np.array([1, 2, 3, 0])  # Upper diagonal (c[-1] is not used)
    >>> d = np.array([1, 2, 3, 4])  # Right hand side
    >>> x = solve_tridiagonal(a, b, c, d)
    >>> bool(np.allclose(x, [-0.2, 1.4, -1.0, 1.4]))
    True

    Notes
    -----
    The Thomas algorithm is a specialized form of Gaussian elimination for
    tridiagonal systems. It is much more efficient than general Gaussian
    elimination, with a time complexity of O(n) instead of O(n³).

    The algorithm consists of two phases:
    1. Forward elimination to transform the matrix into an upper triangular form
    2. Back substitution to find the solution

    No pivoting is performed, so the method fails (with ValueError) on systems
    whose elimination produces an exactly zero pivot.

    For a system where the matrix A is:
    [b₀ c₀ 0  0  0 ]
    [a₁ b₁ c₁ 0  0 ]
    [0  a₂ b₂ c₂ 0 ]
    [0  0  a₃ b₃ c₃]
    [0  0  0  a₄ b₄]

    References
    ----------
    .. [1] Thomas, L.H. (1949). "Elliptic Problems in Linear Differential
           Equations over a Network". Watson Sci. Comput. Lab Report.
    """
    zero_pivot_message = "Pivot cannot be zero. The system cannot be solved with this method."

    # Work on float copies so the inputs stay untouched
    lower = np.array(lower_diagonal, dtype=float)
    diag = np.array(main_diagonal, dtype=float)
    upper = np.array(upper_diagonal, dtype=float)
    rhs = np.array(right_hand_side, dtype=float)

    n = len(rhs)
    if n == 0:
        # An empty system has an empty solution
        return np.zeros(0)

    if len(diag) < n or len(lower) < n or len(upper) < n - 1:
        raise ValueError(
            "Diagonals are too short for the right-hand side: "
            f"need {n} lower, {n} main and {n - 1} upper elements."
        )

    if diag[0] == 0:
        raise ValueError(zero_pivot_message)

    # Forward elimination
    for k in range(1, n):
        factor = lower[k] / diag[k - 1]
        diag[k] -= factor * upper[k - 1]
        # diag[k] is the next divisor (next row or back substitution)
        if diag[k] == 0:
            raise ValueError(zero_pivot_message)
        rhs[k] -= factor * rhs[k - 1]

    # Back substitution
    x = np.zeros(n)
    x[n - 1] = rhs[n - 1] / diag[n - 1]
    for k in range(n - 2, -1, -1):
        x[k] = (rhs[k] - upper[k] * x[k + 1]) / diag[k]

    return x