knit-graphs 0.0.11__py3-none-any.whl → 0.0.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,8 +8,7 @@ from __future__ import annotations
8
8
 
9
9
  import os
10
10
  import sys
11
- from collections.abc import Iterable
12
- from typing import TypedDict, cast
11
+ from typing import Generic, TypedDict, TypeVar, cast
13
12
 
14
13
  import plotly.io as pio
15
14
  from networkx import DiGraph
@@ -21,13 +20,15 @@ from knit_graphs.Knit_Graph import Knit_Graph
21
20
  from knit_graphs.Loop import Loop
22
21
  from knit_graphs.Pull_Direction import Pull_Direction
23
22
 
23
+ LoopT = TypeVar("LoopT", bound=Loop)
24
24
 
25
- class TraceData(TypedDict):
25
+
26
+ class TraceData(TypedDict, Generic[LoopT]):
26
27
  """Typing specification for the dictionaries passed as traces to Plotly"""
27
28
 
28
29
  x: list[float | None]
29
30
  y: list[float | None]
30
- edge: list[tuple[Loop, Loop]]
31
+ edge: list[tuple[LoopT, LoopT]]
31
32
  is_start: list[bool]
32
33
 
33
34
 
@@ -49,7 +50,7 @@ def configure_plotly_environment() -> None:
49
50
  configure_plotly_environment()
50
51
 
51
52
 
52
- class Knit_Graph_Visualizer:
53
+ class Knit_Graph_Visualizer(Generic[LoopT]):
53
54
  """A class used to visualize a knit graph using the plotly graph objects library.
54
55
 
55
56
  This class converts knit graph data structures into interactive 2D visualizations by calculating loop positions,
@@ -71,7 +72,7 @@ class Knit_Graph_Visualizer:
71
72
 
72
73
  def __init__(
73
74
  self,
74
- knit_graph: Knit_Graph,
75
+ knit_graph: Knit_Graph[LoopT],
75
76
  first_course_index: int = 0,
76
77
  top_course_index: int | None = None,
77
78
  start_on_left: bool = True,
@@ -92,18 +93,18 @@ class Knit_Graph_Visualizer:
92
93
  self.balance_by_base_width: bool = balance_by_base_width
93
94
  self.start_on_left: bool = start_on_left
94
95
  self.knit_graph: Knit_Graph = knit_graph
95
- self.courses: list[Course] = knit_graph.get_courses()
96
+ self.courses: list[Course[LoopT]] = knit_graph.get_courses()
96
97
  if top_course_index is None:
97
98
  top_course_index = len(self.courses)
98
99
  self.top_course_index: int = top_course_index
99
100
  self.first_course_index: int = first_course_index
100
101
  self.base_width: float = float(len(self.courses[first_course_index])) # Updates when creating base course.
101
102
  self.base_left: float = 0.0 # Updates when creating the base course.
102
- self.loops_to_course: dict[Loop, Course] = {}
103
+ self.loops_to_course: dict[LoopT, Course] = {}
103
104
  for course in self.courses:
104
105
  self.loops_to_course.update({loop: course for loop in course})
105
106
  self.data_graph: DiGraph = DiGraph()
106
- self._loops_need_placement: set[Loop] = set()
107
+ self._loops_need_placement: set[LoopT] = set()
107
108
  self._loop_markers: list[Scatter] = []
108
109
  self._yarn_traces: list[Scatter] = []
109
110
  self._top_knit_trace_data: TraceData = {
@@ -332,11 +333,7 @@ class Knit_Graph_Visualizer:
332
333
 
333
334
  def _add_cable_edges(self) -> None:
334
335
  """Add all stitch edges that are involved in cable crossings to the appropriate trace data."""
335
- for (
336
- left_loop,
337
- right_loop,
338
- ) in self.knit_graph.braid_graph.loop_crossing_graph.edges:
339
- crossing_direction = self.knit_graph.braid_graph.get_crossing(left_loop, right_loop)
336
+ for left_loop, right_loop, crossing_direction in self.knit_graph.braid_graph.edge_iter:
340
337
  for left_parent in left_loop.parent_loops:
341
338
  self._add_stitch_edge(left_parent, left_loop, crossing_direction)
342
339
  for right_parent in right_loop.parent_loops:
@@ -346,7 +343,7 @@ class Knit_Graph_Visualizer:
346
343
  """Add all stitch edges to the visualization trace data based on their type and cable position."""
347
344
  self._add_cable_edges()
348
345
  # Add remaining stitches as though they have no cable crossing.
349
- for u, v in self.knit_graph.stitch_graph.edges:
346
+ for u, v, _ in self.knit_graph.edge_iter:
350
347
  if (
351
348
  not self._stitch_has_position(u, v) # This edge has not been placed
352
349
  and self._loop_has_position(u)
@@ -354,12 +351,10 @@ class Knit_Graph_Visualizer:
354
351
  ): # Both loops do have positions.
355
352
  self._add_stitch_edge(u, v, Crossing_Direction.No_Cross)
356
353
 
357
- def _add_stitch_edge(self, u: Loop, v: Loop, crossing_direction: Crossing_Direction) -> None:
354
+ def _add_stitch_edge(self, u: LoopT, v: LoopT, crossing_direction: Crossing_Direction) -> None:
358
355
  """Add a single stitch edge to the appropriate trace data based on stitch type and cable crossing."""
359
356
  pull_direction = self.knit_graph.get_pull_direction(u, v)
360
- if pull_direction is None:
361
- return # No edge between these loops
362
- elif pull_direction is Pull_Direction.BtF: # Knit Stitch:
357
+ if pull_direction is Pull_Direction.BtF: # Knit Stitch:
363
358
  if crossing_direction is Crossing_Direction.Over_Right:
364
359
  trace_data = self._top_knit_trace_data
365
360
  elif crossing_direction is Crossing_Direction.Under_Right:
@@ -448,19 +443,13 @@ class Knit_Graph_Visualizer:
448
443
 
449
444
  def _shift_knit_purl(self, shift: float = 0.1) -> None:
450
445
  """Adjust the horizontal position of loops to visually distinguish knit from purl stitches."""
451
- has_knits = any(
452
- self.knit_graph.get_pull_direction(u, v) is Pull_Direction.BtF
453
- for u, v in self.knit_graph.stitch_graph.edges
454
- )
455
- has_purls = any(
456
- self.knit_graph.get_pull_direction(u, v) is Pull_Direction.FtB
457
- for u, v in self.knit_graph.stitch_graph.edges
458
- )
446
+ has_knits = any(pd is Pull_Direction.BtF for _u, _v, pd in self.knit_graph.edge_iter)
447
+ has_purls = any(pd is Pull_Direction.FtB for _u, _v, pd in self.knit_graph.edge_iter)
459
448
  if not (has_knits and has_purls):
460
449
  return # Don't make any changes, because all stitches are of the same type.
461
450
  yarn_over_align = set()
462
451
  for loop in self.data_graph.nodes:
463
- if not loop.has_parent_loops(): # Yarn-over
452
+ if not loop.has_parent_loops: # Yarn-over
464
453
  if self.knit_graph.has_child_loop(loop): # Align yarn-overs with one child to its child
465
454
  yarn_over_align.add(loop)
466
455
  continue # Don't shift yarn-overs
@@ -477,7 +466,7 @@ class Knit_Graph_Visualizer:
477
466
 
478
467
  for loop in yarn_over_align:
479
468
  child_loop = self.knit_graph.get_child_loop(loop)
480
- assert isinstance(child_loop, Loop)
469
+ assert child_loop is not None
481
470
  self._set_x_of_loop(loop, self._get_x_of_loop(child_loop))
482
471
 
483
472
  def _shift_loops_by_float_alignment(self, float_increment: float = 0.25) -> None:
@@ -501,11 +490,11 @@ class Knit_Graph_Visualizer:
501
490
  back_loop, self._get_y_of_loop(back_loop) + float_increment
502
491
  ) # shift loop up to show it is behind the float.
503
492
 
504
- def _get_course_of_loop(self, loop: Loop) -> Course:
493
+ def _get_course_of_loop(self, loop: LoopT) -> Course[LoopT]:
505
494
  """Get the course (horizontal row) that contains the specified loop."""
506
495
  return self.loops_to_course[loop]
507
496
 
508
- def _place_loop(self, loop: Loop, x: float, y: float) -> None:
497
+ def _place_loop(self, loop: LoopT, x: float, y: float) -> None:
509
498
  """Add a loop to the visualization data graph at the specified coordinates."""
510
499
  if self._loop_has_position(loop):
511
500
  self._set_x_of_loop(loop, x)
@@ -513,39 +502,39 @@ class Knit_Graph_Visualizer:
513
502
  else:
514
503
  self.data_graph.add_node(loop, x=x, y=y)
515
504
 
516
- def _set_x_of_loop(self, loop: Loop, x: float) -> None:
505
+ def _set_x_of_loop(self, loop: LoopT, x: float) -> None:
517
506
  """Update the x coordinate of a loop that already exists in the visualization data graph."""
518
507
  if self._loop_has_position(loop):
519
508
  self.data_graph.nodes[loop]["x"] = x
520
509
  else:
521
510
  raise KeyError(f"Loop {loop} is not in the data graph")
522
511
 
523
- def _set_y_of_loop(self, loop: Loop, y: float) -> None:
512
+ def _set_y_of_loop(self, loop: LoopT, y: float) -> None:
524
513
  """Update the y coordinate of a loop that already exists in the visualization data graph."""
525
514
  if self._loop_has_position(loop):
526
515
  self.data_graph.nodes[loop]["y"] = y
527
516
  else:
528
517
  raise KeyError(f"Loop {loop} is not in the data graph")
529
518
 
530
- def _get_x_of_loop(self, loop: Loop) -> float:
519
+ def _get_x_of_loop(self, loop: LoopT) -> float:
531
520
  """Get the x coordinate of a loop from the visualization data graph."""
532
521
  if self._loop_has_position(loop):
533
522
  return float(self.data_graph.nodes[loop]["x"])
534
523
  else:
535
524
  raise KeyError(f"Loop {loop} is not in the data graph")
536
525
 
537
- def _get_y_of_loop(self, loop: Loop) -> float:
526
+ def _get_y_of_loop(self, loop: LoopT) -> float:
538
527
  """Get the y coordinate of a loop from the visualization data graph."""
539
528
  if self._loop_has_position(loop):
540
529
  return float(self.data_graph.nodes[loop]["y"])
541
530
  else:
542
531
  raise KeyError(f"Loop {loop} is not in the data graph")
543
532
 
544
- def _loop_has_position(self, loop: Loop) -> bool:
533
+ def _loop_has_position(self, loop: LoopT) -> bool:
545
534
  """Check if a loop has been positioned in the visualization data graph."""
546
535
  return bool(self.data_graph.has_node(loop))
547
536
 
548
- def _stitch_has_position(self, u: Loop, v: Loop) -> bool:
537
+ def _stitch_has_position(self, u: LoopT, v: LoopT) -> bool:
549
538
  """Check if a stitch edge between two loops has been added to the visualization data graph."""
550
539
  return bool(self.data_graph.has_edge(u, v))
551
540
 
@@ -559,7 +548,7 @@ class Knit_Graph_Visualizer:
559
548
  self._balance_course(course)
560
549
  y += course_spacing # Shift y coordinate up with each course
561
550
 
562
- def _swap_loops_in_cables(self, course: Course) -> None:
551
+ def _swap_loops_in_cables(self, course: Course[LoopT]) -> None:
563
552
  """Swap the horizontal positions of loops involved in cable crossings within a course."""
564
553
  for left_loop in course:
565
554
  for right_loop in self.knit_graph.braid_graph.left_crossing_loops(left_loop):
@@ -571,7 +560,7 @@ class Knit_Graph_Visualizer:
571
560
  self._set_x_of_loop(left_loop, self._get_x_of_loop(right_loop))
572
561
  self._set_x_of_loop(right_loop, left_x)
573
562
 
574
- def _place_loops_by_parents(self, course: Course, y: float) -> None:
563
+ def _place_loops_by_parents(self, course: Course[LoopT], y: float) -> None:
575
564
  """Position loops in a course based on the average position of their parent loops."""
576
565
  for _x, loop in enumerate(course):
577
566
  self._set_loop_x_by_parent_average(loop, y)
@@ -584,7 +573,7 @@ class Knit_Graph_Visualizer:
584
573
  assert len(self._loops_need_placement) == 0, f"Loops {self._loops_need_placement} remain unplaced."
585
574
  # Loops past the first course should have at least one yarn neighbor to place them.
586
575
 
587
- def _set_loop_x_by_parent_average(self, loop: Loop, y: float) -> None:
576
+ def _set_loop_x_by_parent_average(self, loop: LoopT, y: float) -> None:
588
577
  """Set the x coordinate of a loop based on the weighted average position of its parent loops."""
589
578
  if len(loop.parent_loops) == 0:
590
579
  self._loops_need_placement.add(loop)
@@ -604,7 +593,7 @@ class Knit_Graph_Visualizer:
604
593
  x = sum(parent_positions.keys()) / sum(parent_positions.values())
605
594
  self._place_loop(loop, x=x, y=y)
606
595
 
607
- def _set_loop_between_yarn_neighbors(self, loop: Loop, y: float, spacing: float = 1.0) -> bool:
596
+ def _set_loop_between_yarn_neighbors(self, loop: LoopT, y: float, spacing: float = 1.0) -> bool:
608
597
  """Position a loop based on the average position of its neighboring loops along the yarn."""
609
598
  spacing = abs(spacing) # Ensure spacing is positive.
610
599
  x_neighbors = []
@@ -647,13 +636,13 @@ class Knit_Graph_Visualizer:
647
636
  self.base_width = max_x - self.base_left
648
637
 
649
638
  def _get_base_round_course_positions(
650
- self, base_course: Course, loop_space: float = 1.0, back_shift: float = 0.5
639
+ self, base_course: Course[LoopT], loop_space: float = 1.0, back_shift: float = 0.5
651
640
  ) -> None:
652
641
  """Position loops in the base course for circular/tube knitting structure."""
653
642
  split_index = len(base_course) // 2 # Split the course in half to form a tube.
654
- front_loops: list[Loop] = cast(list[Loop], base_course[:split_index])
655
- front_set: set[Loop] = set(front_loops)
656
- back_loops: list[Loop] = cast(list[Loop], base_course[split_index:])
643
+ front_loops = base_course[:split_index]
644
+ front_set = set(front_loops)
645
+ back_loops = base_course[split_index:]
657
646
  if self.start_on_left:
658
647
  back_loops = [*reversed(back_loops)]
659
648
  else:
@@ -677,15 +666,12 @@ class Knit_Graph_Visualizer:
677
666
  else:
678
667
  self._place_loop(back_loop, x=(x * loop_space) - back_shift, y=0)
679
668
 
680
- def _get_base_row_course_positions(self, base_course: Course, loop_space: float = 1.0) -> None:
669
+ def _get_base_row_course_positions(self, base_course: Course[LoopT], loop_space: float = 1.0) -> None:
681
670
  """Position loops in the base course for flat/row knitting structure."""
682
- loops: Iterable[Loop] = list(base_course)
683
- if not self.start_on_left:
684
- loops = reversed(base_course)
685
- for x, loop in enumerate(loops):
671
+ for x, loop in enumerate(base_course if self.start_on_left else reversed(base_course)):
686
672
  self._place_loop(loop, x=x * loop_space, y=0)
687
673
 
688
- def _left_align_course(self, course: Course) -> None:
674
+ def _left_align_course(self, course: Course[LoopT]) -> None:
689
675
  """Align the leftmost loop of a course to x=0 if left alignment is enabled."""
690
676
  if self.left_zero_align:
691
677
  current_left = min(self._get_x_of_loop(loop) for loop in course)
@@ -693,21 +679,23 @@ class Knit_Graph_Visualizer:
693
679
  for loop in course:
694
680
  self._set_x_of_loop(loop, self._get_x_of_loop(loop) - current_left)
695
681
 
696
- def _balance_course(self, course: Course) -> None:
682
+ def _balance_course(self, course: Course[LoopT]) -> None:
697
683
  """Scale the width of a course to match the base course width if balancing is enabled."""
698
684
  current_left = min(self._get_x_of_loop(loop) for loop in course)
699
685
  max_x = max(self._get_x_of_loop(loop) for loop in course)
700
686
  course_width = max_x - current_left
701
687
  if self.balance_by_base_width and course_width != self.base_width:
702
688
 
703
- def _target_distance_from_left(l: Loop) -> float:
689
+ def _target_distance_from_left(l: LoopT) -> float:
704
690
  current_distance_from_left = self._get_x_of_loop(l) - current_left
705
691
  return (current_distance_from_left * self.base_width) / course_width
706
692
 
707
693
  for loop in course:
708
694
  self._set_x_of_loop(loop, _target_distance_from_left(loop) + current_left)
709
695
 
710
- def x_coordinate_differences(self, other: Knit_Graph_Visualizer) -> dict[Loop, tuple[float | None, float | None]]:
696
+ def x_coordinate_differences(
697
+ self, other: Knit_Graph_Visualizer[LoopT]
698
+ ) -> dict[LoopT, tuple[float | None, float | None]]:
711
699
  """Find the differences in x-coordinates between two knitgraph visualizations. Used for testing and comparing visualization results.
712
700
 
713
701
  Args:
@@ -721,28 +709,28 @@ class Knit_Graph_Visualizer:
721
709
  ** The second value of each tuple is the x-coordinate of the loop in the other visualization or None if the loop is not in that visualization.
722
710
 
723
711
  """
724
- differences: dict[Loop, tuple[float | None, float | None]] = {
725
- cast(Loop, l): (self._get_x_of_loop(l), None)
712
+ differences: dict[LoopT, tuple[float | None, float | None]] = {
713
+ cast(LoopT, l): (self._get_x_of_loop(l), None)
726
714
  for l in self.data_graph.nodes
727
715
  if not other.data_graph.has_node(l)
728
716
  }
729
717
  differences.update(
730
718
  {
731
- cast(Loop, l): (None, other._get_x_of_loop(l))
719
+ cast(LoopT, l): (None, other._get_x_of_loop(l))
732
720
  for l in other.data_graph.nodes
733
721
  if not self.data_graph.has_node(l)
734
722
  }
735
723
  )
736
724
  differences.update(
737
725
  {
738
- cast(Loop, l): (self._get_x_of_loop(l), other._get_x_of_loop(l))
726
+ cast(LoopT, l): (self._get_x_of_loop(l), other._get_x_of_loop(l))
739
727
  for l in self.data_graph.nodes
740
728
  if other.data_graph.has_node(l) and self._get_x_of_loop(l) != other._get_x_of_loop(l)
741
729
  }
742
730
  )
743
731
  return differences
744
732
 
745
- def y_coordinate_differences(self, other: Knit_Graph_Visualizer) -> dict[Loop, tuple[float | None, float | None]]:
733
+ def y_coordinate_differences(self, other: Knit_Graph_Visualizer) -> dict[LoopT, tuple[float | None, float | None]]:
746
734
  """Find the differences in y-coordinates between two knitgraph visualizations. Used for testing and comparing visualization results.
747
735
 
748
736
  Args:
@@ -756,21 +744,21 @@ class Knit_Graph_Visualizer:
756
744
  ** The second value of each tuple is the y-coordinate of the loop in the other visualization or None if the loop is not in that visualization.
757
745
 
758
746
  """
759
- differences: dict[Loop, tuple[float | None, float | None]] = {
760
- cast(Loop, l): (self._get_y_of_loop(l), None)
747
+ differences: dict[LoopT, tuple[float | None, float | None]] = {
748
+ cast(LoopT, l): (self._get_y_of_loop(l), None)
761
749
  for l in self.data_graph.nodes
762
750
  if not other.data_graph.has_node(l)
763
751
  }
764
752
  differences.update(
765
753
  {
766
- cast(Loop, l): (None, other._get_y_of_loop(l))
754
+ cast(LoopT, l): (None, other._get_y_of_loop(l))
767
755
  for l in other.data_graph.nodes
768
756
  if not self.data_graph.has_node(l)
769
757
  }
770
758
  )
771
759
  differences.update(
772
760
  {
773
- cast(Loop, l): (self._get_y_of_loop(l), other._get_y_of_loop(l))
761
+ cast(LoopT, l): (self._get_y_of_loop(l), other._get_y_of_loop(l))
774
762
  for l in self.data_graph.nodes
775
763
  if other.data_graph.has_node(l) and self._get_y_of_loop(l) != other._get_y_of_loop(l)
776
764
  }