pandas-survey-toolkit 1.0.3__py3-none-any.whl → 1.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,760 +1,198 @@
1
- import textwrap
2
- from typing import List, Tuple
3
-
4
- import altair as alt
5
- import matplotlib.colors as mcolors
6
- import networkx as nx
7
- import numpy as np
8
- import pandas as pd
9
- from IPython.display import HTML, display
10
- from matplotlib import pyplot as plt
11
- from pyvis.network import Network
12
- from scipy import stats
13
- from sklearn.preprocessing import MinMaxScaler
14
-
15
-
16
- def cluster_heatmap_plot(df: pd.DataFrame, x: str, y: List[str], max_width: int = 75):
17
- """
18
- Create a heatmap visualization of Likert scale responses grouped by clusters.
19
-
20
- This function generates an interactive Altair visualization showing the distribution
21
- of positive and negative responses across different clusters for each question.
22
- The visualization consists of two parts:
23
- 1. A bar chart showing the number of respondents in each cluster
24
- 2. A heatmap showing the sentiment distribution for each question by cluster
25
-
26
- Parameters
27
- ----------
28
- df : pd.DataFrame
29
- The DataFrame containing the clustered data and encoded Likert responses.
30
- Should include a cluster column and encoded Likert columns.
31
-
32
- x : str
33
- The name of the column containing cluster IDs (e.g., 'question_cluster_id').
34
-
35
- y : List[str]
36
- List of column names containing the encoded Likert responses.
37
- These should typically be columns with values -1, 0, 1 representing
38
- negative, neutral, and positive responses.
39
-
40
- max_width : int, default=75
41
- Maximum width for wrapping question labels in the visualization.
42
-
43
- Returns
44
- -------
45
- alt.VConcatChart
46
- An Altair chart object combining a bar chart of cluster sizes and
47
- a heatmap of sentiment distribution that can be displayed in a Jupyter notebook
48
- or exported as HTML.
49
-
50
- Notes
51
- -----
52
- The function color-codes the heatmap cells based on the percentage of
53
- positive and negative responses, with green representing positive sentiment,
54
- red representing negative sentiment, and varying shades for mixed responses.
55
-
56
- The encoded Likert columns (y parameter) should contain values that are encoded as:
57
- * 1 for positive responses
58
- * 0 for neutral responses
59
- * -1 for negative responses
60
-
61
- Examples
62
- --------
63
- >>> # Assuming df has been processed with cluster_questions
64
- >>> likert_columns = [f"likert_encoded_{q}" for q in questions]
65
- >>> heatmap = cluster_heatmap_plot(df, x="question_cluster_id", y=likert_columns)
66
- >>> display(heatmap)
67
- """
68
- # Convert -1, 0, 1 to percent positive and percent negative
69
- df_positive = df[y].apply(lambda col: (col == 1).astype(int))
70
- df_negative = df[y].apply(lambda col: (col == -1).astype(int))
71
-
72
- # Calculate average percent positive and negative for each cluster and question
73
- heatmap_data_pos = (
74
- df_positive.groupby(df[x])
75
- .mean()
76
- .reset_index()
77
- .melt(id_vars=x, var_name="question", value_name="percent_positive")
78
- )
79
- heatmap_data_neg = (
80
- df_negative.groupby(df[x])
81
- .mean()
82
- .reset_index()
83
- .melt(id_vars=x, var_name="question", value_name="percent_negative")
84
- )
85
-
86
- # Merge positive and negative data
87
- heatmap_data = pd.merge(heatmap_data_pos, heatmap_data_neg, on=[x, "question"])
88
- heatmap_data["percent_neutral"] = (
89
- 1 - heatmap_data["percent_positive"] - heatmap_data["percent_negative"]
90
- )
91
-
92
- # Calculate overall positivity for each cluster
93
- cluster_positivity = (
94
- heatmap_data.groupby(x)["percent_positive"].mean().sort_values(ascending=False)
95
- )
96
- cluster_order = cluster_positivity.index.tolist()
97
-
98
- # Replace underscores with spaces in question labels
99
- heatmap_data["question"] = (
100
- heatmap_data["question"].str.replace("_", " ").str.replace("likert encoded", "")
101
- )
102
-
103
- # Wrap long question labels
104
- wrapped_labels = [
105
- textwrap.fill(label, width=max_width)
106
- for label in heatmap_data["question"].unique()
107
- ]
108
- label_to_wrapped = dict(zip(heatmap_data["question"].unique(), wrapped_labels))
109
- heatmap_data["wrapped_question"] = heatmap_data["question"].map(label_to_wrapped)
110
-
111
- # Define color scale based on percent positive and percent negative
112
- def get_color(pos: float, neg: float) -> Tuple[str, str]:
113
- if pos > 0.6:
114
- return "#1a9641", "white" # Strong positive (green)
115
- elif pos > 0.4:
116
- return "#a6d96a", "black" # Moderate positive (light green)
117
- elif pos > neg:
118
- return "#ffffbf", "black" # Slightly positive (light yellow)
119
- elif neg > 0.6:
120
- return "#d7191c", "white" # Strong negative (red)
121
- elif neg > 0.4:
122
- return "#fdae61", "black" # Moderate negative (orange)
123
- elif neg > pos:
124
- return "#f4a582", "black" # Slightly negative (light red)
125
- else:
126
- return "#f7f7f7", "black" # Neutral (light gray)
127
-
128
- heatmap_data["background_color"], heatmap_data["text_color"] = zip(
129
- *heatmap_data.apply(
130
- lambda row: get_color(row["percent_positive"], row["percent_negative"]),
131
- axis=1,
132
- )
133
- )
134
-
135
- # Calculate chart dimensions
136
- chart_width = 600
137
- row_height = 30
138
- heatmap_height = len(wrapped_labels) * row_height
139
- bar_chart_height = 100
140
-
141
- # Create heatmap
142
- heatmap = (
143
- alt.Chart(heatmap_data)
144
- .mark_rect()
145
- .encode(
146
- x=alt.X(f"{x}:O", title="Cluster ID", sort=cluster_order),
147
- y=alt.Y("wrapped_question:O", title=None, sort=wrapped_labels),
148
- color=alt.Color("background_color:N", scale=None),
149
- tooltip=[
150
- alt.Tooltip(f"{x}:O", title="Cluster ID"),
151
- alt.Tooltip("question:O", title="Question"),
152
- alt.Tooltip("percent_positive:Q", title="% Positive", format=".2%"),
153
- alt.Tooltip("percent_negative:Q", title="% Negative", format=".2%"),
154
- alt.Tooltip("percent_neutral:Q", title="% Neutral", format=".2%"),
155
- ],
156
- )
157
- .properties(
158
- width=chart_width,
159
- height=heatmap_height,
160
- title="Cluster Heatmap: Sentiment Distribution",
161
- )
162
- )
163
-
164
- # Add text labels to heatmap
165
- text = heatmap.mark_text(baseline="middle").encode(
166
- text=alt.Text("percent_positive:Q", format=".0%"),
167
- color=alt.Color("text_color:N", scale=None),
168
- )
169
-
170
- # Create bar chart for cluster counts
171
- cluster_counts = df[x].value_counts().reset_index()
172
- cluster_counts.columns = [x, "count"]
173
- cluster_counts[x] = pd.Categorical(
174
- cluster_counts[x], categories=cluster_order, ordered=True
175
- )
176
- cluster_counts = cluster_counts.sort_values(x)
177
-
178
- bar_chart = (
179
- alt.Chart(cluster_counts)
180
- .mark_bar()
181
- .encode(
182
- x=alt.X(f"{x}:O", title="Cluster ID", sort=cluster_order),
183
- y=alt.Y("count:Q", title="Count"),
184
- tooltip=[
185
- alt.Tooltip(f"{x}:O", title="Cluster ID"),
186
- alt.Tooltip("count:Q", title="Count"),
187
- ],
188
- )
189
- .properties(width=chart_width, height=bar_chart_height, title="Cluster Sizes")
190
- )
191
-
192
- # Add text labels to bar chart
193
- bar_text = bar_chart.mark_text(align="center", baseline="bottom", dy=-5).encode(
194
- text="count:Q"
195
- )
196
-
197
- # Combine bar chart and heatmap using vconcat
198
- combined_chart = (
199
- alt.vconcat((bar_chart + bar_text), (heatmap + text))
200
- .configure_view(strokeWidth=0)
201
- .configure_axis(
202
- labelLimit=350 # Increase label limit to show full wrapped text
203
- )
204
- )
205
-
206
- return combined_chart
207
-
208
-
209
- def create_keyword_graph(df: pd.DataFrame, keyword_column: str) -> nx.DiGraph:
210
- """
211
- Create a NetworkX DiGraph from a DataFrame column containing lists of keywords.
212
-
213
- Parameters:
214
- df (pd.DataFrame): The input DataFrame
215
- keyword_column (str): The name of the column containing keyword lists
216
-
217
- Returns:
218
- nx.DiGraph: A directed graph representing keyword connections
219
- """
220
- G = nx.DiGraph()
221
-
222
- for keywords in df[keyword_column]:
223
- if not isinstance(keywords, list) or len(keywords) == 0:
224
- continue
225
-
226
- # populate nodes first
227
- for i, keyword in enumerate(keywords):
228
-
229
- keyword = keyword.encode("utf-8").decode("utf-8")
230
- # Add node or update node count
231
- if G.has_node(keyword):
232
- G.nodes[keyword]["node_count"] += 1
233
- else:
234
- G.add_node(keyword, node_count=1, title=keyword)
235
- # now popualte edges
236
- if len(keywords) < 2:
237
- continue
238
- for i, keyword in enumerate(keywords):
239
- # Add edge to next keyword if it exists
240
- keyword = keyword.encode("utf-8").decode("utf-8")
241
-
242
- if i < len(keywords) - 1:
243
- next_keyword = keywords[i + 1]
244
- next_keyword = next_keyword.encode("utf-8").decode("utf-8")
245
- if G.has_edge(keyword, next_keyword):
246
- G[keyword][next_keyword]["edge_count"] += 1
247
- else:
248
- G.add_edge(keyword, next_keyword, edge_count=1)
249
-
250
- return G
251
-
252
-
253
- def visualize_keyword_graph(
254
- graph: nx.DiGraph,
255
- output_file: str = None,
256
- min_edge_count: int = 4,
257
- min_node_count: int = 4,
258
- ):
259
- """
260
- Visualize a filtered NetworkX DiGraph using PyViz Network.
261
-
262
- Parameters:
263
- graph (nx.DiGraph): The input graph to visualize
264
- output_file (str): The name of the output HTML file (default: None)
265
- notebook (bool): Whether to display the graph inline in a Jupyter notebook (default: True)
266
- min_edge_count (int): Minimum edge count to include in the visualization (default: 5)
267
- min_node_count (int): Minimum node count to include in the visualization (default: 8)
268
-
269
- Returns:
270
- Network: The PyViz Network object for further customization if needed
271
- """
272
- # Filter the graph based on min_edge_count and min_node_count
273
- filtered_graph = nx.DiGraph()
274
-
275
- for node, data in graph.nodes(data=True):
276
- if data["node_count"] >= min_node_count:
277
- filtered_graph.add_node(node, **data)
278
-
279
- for source, target, data in graph.edges(data=True):
280
- if (
281
- data["edge_count"] >= min_edge_count
282
- and source in filtered_graph.nodes
283
- and target in filtered_graph.nodes
284
- ):
285
- filtered_graph.add_edge(source, target, **data)
286
-
287
- # Remove isolated nodes (nodes with no edges)
288
- filtered_graph.remove_nodes_from(list(nx.isolates(filtered_graph)))
289
-
290
- # Create a PyViz Network object
291
- net = Network(directed=True, width="100%", height="800px")
292
-
293
- # Add nodes to the network
294
- for node, data in filtered_graph.nodes(data=True):
295
- net.add_node(
296
- node,
297
- label=data["title"],
298
- title=f"Keyword: {data['title']}\nCount: {data['node_count']}",
299
- size=data["node_count"] * 1,
300
- ) # Adjust the multiplier as needed
301
-
302
- # Add edges to the network
303
- for source, target, data in filtered_graph.edges(data=True):
304
- net.add_edge(
305
- source,
306
- target,
307
- title=f"Count: {data['edge_count']}",
308
- width=data["edge_count"],
309
- ) # Edge thickness based on count
310
-
311
- # Set some display options
312
- net.set_options(
313
- """
314
- var options = {
315
- "edges": {
316
- "arrows": {
317
- "to": {
318
- "enabled": true
319
- }
320
- },
321
- "color": {
322
- "inherit": true
323
- },
324
- "smooth": false
325
- },
326
- "physics": {
327
- "minVelocity": 0.1
328
- }
329
- }
330
- """
331
- )
332
-
333
- # Save the graph as an interactive HTML file if output_file is provided
334
- if output_file:
335
- net.save_graph(output_file)
336
- print(f"Graph saved to {output_file}")
337
-
338
- # Display the graph inline if in a notebook
339
- else:
340
- net.save_graph("keyword_graph.html")
341
-
342
-
343
- def dense_rank(series):
344
- """
345
- Compute dense rank for a series.
346
- This will assign the same rank to tied values, but ranks will be continuous.
347
- """
348
- return stats.rankdata(series, method="dense")
349
-
350
-
351
- def create_keyword_sentiment_df_simple(df):
352
- sentiment_counts = df["sentiment"].value_counts()
353
- total_positive = sentiment_counts.get("positive", 0)
354
- total_negative = sentiment_counts.get("negative", 0)
355
-
356
- keyword_sentiments = {}
357
-
358
- for _, row in df.iterrows():
359
- sentiment = row["sentiment"]
360
- if sentiment == "neutral":
361
- continue
362
- for word in row["keywords"]:
363
- if word not in keyword_sentiments:
364
- keyword_sentiments[word] = {"positive": 0, "negative": 0}
365
- keyword_sentiments[word][sentiment] += 1
366
-
367
- result_df = pd.DataFrame(
368
- [
369
- {
370
- "word": word,
371
- "sentiment_score": (
372
- counts["positive"] / total_positive if total_positive else 0
373
- )
374
- - (counts["negative"] / total_negative if total_negative else 0),
375
- }
376
- for word, counts in keyword_sentiments.items()
377
- ]
378
- )
379
-
380
- return result_df
381
-
382
-
383
- def create_keyword_sentiment_df(df):
384
- sentiment_counts = df["sentiment"].value_counts()
385
- total_positive = sentiment_counts.get("positive", 0)
386
- total_negative = sentiment_counts.get("negative", 0)
387
-
388
- positive_keywords = {}
389
- negative_keywords = {}
390
-
391
- for _, row in df.iterrows():
392
- if row["sentiment"] == "positive":
393
- for word in row["keywords"]:
394
- positive_keywords[word] = positive_keywords.get(word, 0) + 1
395
- elif row["sentiment"] == "negative":
396
- for word in row["keywords"]:
397
- negative_keywords[word] = negative_keywords.get(word, 0) + 1
398
-
399
- all_keywords = set(positive_keywords.keys()) | set(negative_keywords.keys())
400
-
401
- result_df = pd.DataFrame(
402
- {
403
- "word": list(all_keywords),
404
- "sentiment_positive": [
405
- positive_keywords.get(word, 0) / total_positive if total_positive else 0
406
- for word in all_keywords
407
- ],
408
- "sentiment_negative": [
409
- negative_keywords.get(word, 0) / total_negative if total_negative else 0
410
- for word in all_keywords
411
- ],
412
- }
413
- )
414
-
415
- # Apply dense ranking
416
- result_df["sentiment_positive_rank"] = dense_rank(result_df["sentiment_positive"])
417
- result_df["sentiment_negative_rank"] = dense_rank(result_df["sentiment_negative"])
418
-
419
- # Normalize ranks to [0, 1] range
420
- result_df["sentiment_positive_scaled"] = (
421
- result_df["sentiment_positive_rank"] - 1
422
- ) / (result_df["sentiment_positive_rank"].max() - 1)
423
- result_df["sentiment_negative_scaled"] = (
424
- result_df["sentiment_negative_rank"] - 1
425
- ) / (result_df["sentiment_negative_rank"].max() - 1)
426
-
427
- # Add small random jitter to avoid perfect overlaps
428
- jitter = 0.01
429
- result_df["sentiment_positive_jittered"] = result_df[
430
- "sentiment_positive_scaled"
431
- ] + np.random.uniform(-jitter, jitter, len(result_df))
432
- result_df["sentiment_negative_jittered"] = result_df[
433
- "sentiment_negative_scaled"
434
- ] + np.random.uniform(-jitter, jitter, len(result_df))
435
-
436
- # Add color coding
437
- result_df["color"] = np.where(
438
- result_df["sentiment_positive_jittered"]
439
- > result_df["sentiment_negative_jittered"],
440
- "blue",
441
- "red",
442
- )
443
-
444
- return result_df
445
-
446
-
447
- def create_sentiment_color_mapping(sentiment_df):
448
- """
449
- Create a dictionary mapping keywords to normalized sentiment scores.
450
- """
451
- scaler = MinMaxScaler(feature_range=(-1, 1))
452
- normalized_scores = scaler.fit_transform(sentiment_df[["sentiment_score"]])
453
- return dict(zip(sentiment_df["word"], normalized_scores.flatten()))
454
-
455
-
456
- def create_keyword_graph(
457
- df: pd.DataFrame, keyword_column: str, node_color_mapping: dict = None
458
- ) -> nx.DiGraph:
459
- """
460
- Create a NetworkX DiGraph from a DataFrame column containing lists of keywords.
461
-
462
- Parameters:
463
- df (pd.DataFrame): The input DataFrame
464
- keyword_column (str): The name of the column containing keyword lists
465
- node_color_mapping (dict): Optional dictionary mapping keywords to color values
466
-
467
- Returns:
468
- nx.DiGraph: A directed graph representing keyword connections
469
- """
470
- G = nx.DiGraph()
471
-
472
- for keywords in df[keyword_column]:
473
- if not isinstance(keywords, list) or len(keywords) == 0:
474
- continue
475
-
476
- for i, keyword in enumerate(keywords):
477
- keyword = keyword.encode("utf-8").decode("utf-8")
478
-
479
- if G.has_node(keyword):
480
- G.nodes[keyword]["node_count"] += 1
481
- else:
482
- G.add_node(keyword, node_count=1, title=keyword)
483
- if node_color_mapping and keyword in node_color_mapping:
484
- G.nodes[keyword]["color_value"] = node_color_mapping[keyword]
485
-
486
- if len(keywords) < 2:
487
- continue
488
- for i, keyword in enumerate(keywords):
489
- keyword = keyword.encode("utf-8").decode("utf-8")
490
- if i < len(keywords) - 1:
491
- next_keyword = keywords[i + 1].encode("utf-8").decode("utf-8")
492
- if G.has_edge(keyword, next_keyword):
493
- G[keyword][next_keyword]["edge_count"] += 1
494
- else:
495
- G.add_edge(keyword, next_keyword, edge_count=1)
496
-
497
- return G
498
-
499
-
500
- def visualize_keyword_graph(
501
- graph: nx.DiGraph,
502
- output_file: str = None,
503
- min_edge_count: int = 4,
504
- min_node_count: int = 4,
505
- colormap: str = "RdYlBu",
506
- ):
507
- """
508
- Visualize a filtered NetworkX DiGraph using PyViz Network.
509
-
510
- Parameters:
511
- graph (nx.DiGraph): The input graph to visualize
512
- output_file (str): The name of the output HTML file (default: None)
513
- min_edge_count (int): Minimum edge count to include in the visualization (default: 4)
514
- min_node_count (int): Minimum node count to include in the visualization (default: 4)
515
- colormap (str): Name of the matplotlib colormap to use (default: 'RdYlBu')
516
- """
517
- # Filter nodes based on min_node_count
518
- nodes_to_keep = [
519
- node
520
- for node, data in graph.nodes(data=True)
521
- if data["node_count"] >= min_node_count
522
- ]
523
- filtered_graph = graph.subgraph(nodes_to_keep).copy()
524
-
525
- # Filter edges based on min_edge_count
526
- edges_to_remove = [
527
- (u, v)
528
- for u, v, data in filtered_graph.edges(data=True)
529
- if data["edge_count"] < min_edge_count
530
- ]
531
- filtered_graph.remove_edges_from(edges_to_remove)
532
-
533
- # Remove isolated nodes
534
- filtered_graph.remove_nodes_from(list(nx.isolates(filtered_graph)))
535
-
536
- net = Network(directed=True, width="100%", height="800px")
537
-
538
- cmap = plt.get_cmap(colormap)
539
-
540
- for node, data in filtered_graph.nodes(data=True):
541
- if "color_value" in data:
542
- # Map the color_value from [-1, 1] to [0, 1] for the colormap
543
- color_val = (data["color_value"] + 1) / 2
544
- node_color = mcolors.rgb2hex(cmap(color_val))
545
- else:
546
- node_color = None
547
-
548
- net.add_node(
549
- node,
550
- label=data["title"],
551
- title=f"Keyword: {data['title']}\nCount: {data['node_count']}\nSentiment: {data.get('color_value', 0):.2f}",
552
- size=data["node_count"] * 1,
553
- color=node_color,
554
- )
555
-
556
- for source, target, data in filtered_graph.edges(data=True):
557
- net.add_edge(
558
- source,
559
- target,
560
- title=f"Count: {data['edge_count']}",
561
- width=data["edge_count"],
562
- )
563
-
564
- net.set_options(
565
- """
566
- var options = {
567
- "nodes": {
568
- "font": {
569
- "size": 12
570
- }
571
- },
572
- "edges": {
573
- "arrows": {
574
- "to": {
575
- "enabled": true
576
- }
577
- },
578
- "color": {
579
- "inherit": true
580
- },
581
- "smooth": false
582
- },
583
- "physics": {
584
- "barnesHut": {
585
- "gravitationalConstant": -2000,
586
- "centralGravity": 0.3,
587
- "springLength": 95
588
- },
589
- "minVelocity": 0.75
590
- }
591
- }
592
- """
593
- )
594
-
595
- if output_file:
596
- net.save_graph(output_file)
597
- print(f"Graph saved to {output_file}")
598
- else:
599
- net.save_graph("keyword_graph.html")
600
-
601
-
602
- def visualize_keyword_graph_force(
603
- graph: nx.DiGraph,
604
- output_file: str = None,
605
- min_edge_count: int = 4,
606
- min_node_count: int = 4,
607
- colormap: str = "RdYlBu",
608
- canvas_width: int = 1000,
609
- canvas_height: int = 800,
610
- ):
611
- """
612
- Visualize a filtered NetworkX DiGraph using PyViz Network with sentiment-based positioning.
613
-
614
- Parameters:
615
- graph (nx.DiGraph): The input graph to visualize
616
- output_file (str): The name of the output HTML file (default: None)
617
- min_edge_count (int): Minimum edge count to include in the visualization (default: 4)
618
- min_node_count (int): Minimum node count to include in the visualization (default: 4)
619
- colormap (str): Name of the matplotlib colormap to use (default: 'RdYlBu')
620
- canvas_width (int): Width of the canvas in pixels (default: 1000)
621
- canvas_height (int): Height of the canvas in pixels (default: 800)
622
- """
623
- # Filter nodes based on min_node_count
624
- nodes_to_keep = [
625
- node
626
- for node, data in graph.nodes(data=True)
627
- if data["node_count"] >= min_node_count
628
- ]
629
- filtered_graph = graph.subgraph(nodes_to_keep).copy()
630
-
631
- # Filter edges based on min_edge_count
632
- edges_to_remove = [
633
- (u, v)
634
- for u, v, data in filtered_graph.edges(data=True)
635
- if data["edge_count"] < min_edge_count
636
- ]
637
- filtered_graph.remove_edges_from(edges_to_remove)
638
-
639
- # Remove isolated nodes
640
- filtered_graph.remove_nodes_from(list(nx.isolates(filtered_graph)))
641
-
642
- net = Network(directed=True, width=f"{canvas_width}px", height=f"{canvas_height}px")
643
-
644
- cmap = plt.get_cmap(colormap)
645
-
646
- for node, data in filtered_graph.nodes(data=True):
647
- if "color_value" in data:
648
- # Map the color_value from [-1, 1] to [0, 1] for the colormap
649
- color_val = (data["color_value"] + 1) / 2
650
- node_color = mcolors.rgb2hex(cmap(color_val))
651
-
652
- # Set x position based on sentiment (color_value)
653
- x_pos = int((data["color_value"] + 1) * canvas_width / 2)
654
- else:
655
- node_color = None
656
- x_pos = canvas_width // 2 # Neutral position for nodes without sentiment
657
-
658
- net.add_node(
659
- node,
660
- label=data["title"],
661
- title=f"Keyword: {data['title']}\nCount: {data['node_count']}\nSentiment: {data.get('color_value', 0):.2f}",
662
- size=data["node_count"] * 1,
663
- color=node_color,
664
- x=x_pos,
665
- y=None,
666
- ) # Let y be determined by the physics engine
667
-
668
- for source, target, data in filtered_graph.edges(data=True):
669
- net.add_edge(
670
- source,
671
- target,
672
- title=f"Count: {data['edge_count']}",
673
- width=data["edge_count"],
674
- )
675
-
676
- net.set_options(
677
- f"""
678
- var options = {{
679
- "nodes": {{
680
- "font": {{
681
- "size": 12
682
- }}
683
- }},
684
- "edges": {{
685
- "arrows": {{
686
- "to": {{
687
- "enabled": true
688
- }}
689
- }},
690
- "color": {{
691
- "inherit": true
692
- }},
693
- "smooth": false
694
- }},
695
- "physics": {{
696
- "barnesHut": {{
697
- "gravitationalConstant": -2000,
698
- "centralGravity": 0.3,
699
- "springLength": 95
700
- }},
701
- "minVelocity": 0.75
702
- }},
703
- "layout": {{
704
- "randomSeed": 42
705
- }}
706
- }}
707
- """
708
- )
709
-
710
- if output_file:
711
- net.save_graph(output_file)
712
- print(f"Graph saved to {output_file}")
713
- else:
714
- net.save_graph("keyword_graph.html")
715
-
716
-
717
- def plot_word_sentiment(df):
718
- # Create the scatter plot
719
- scatter = (
720
- alt.Chart(df)
721
- .mark_circle()
722
- .encode(
723
- x=alt.X(
724
- "sentiment_positive_jittered:Q",
725
- title="Positive Sentiment (Dense Rank)",
726
- scale=alt.Scale(domain=[-0.05, 1.05]),
727
- ),
728
- y=alt.Y(
729
- "sentiment_negative_jittered:Q",
730
- title="Negative Sentiment (Dense Rank)",
731
- scale=alt.Scale(domain=[-0.05, 1.05]),
732
- ),
733
- color=alt.Color(
734
- "color:N",
735
- scale=alt.Scale(domain=["blue", "red"], range=["blue", "red"]),
736
- ),
737
- tooltip=["word", "sentiment_positive", "sentiment_negative"],
738
- )
739
- )
740
-
741
- # Create the text labels
742
- text = scatter.mark_text(align="left", baseline="middle", dx=7).encode(text="word")
743
-
744
- # Create the y=x reference line
745
- line = (
746
- alt.Chart(pd.DataFrame({"x": [0, 1]}))
747
- .mark_line(color="green", strokeDash=[4, 4])
748
- .encode(x="x", y="x")
749
- )
750
-
751
- # Combine the scatter plot, text labels, and reference line
752
- chart = (
753
- (scatter + text + line)
754
- .properties(
755
- width=600, height=600, title="Word Sentiment Analysis (Dense Rank Scaling)"
756
- )
757
- .interactive()
758
- )
759
-
760
- return chart
1
+ import textwrap
2
+ from typing import List, Tuple
3
+
4
+ import altair as alt
5
+ import pandas as pd
6
+
7
+
8
+ def cluster_heatmap_plot(df: pd.DataFrame, x: str, y: List[str], max_width: int = 75):
9
+ """
10
+ Create a heatmap visualization of Likert scale responses grouped by clusters.
11
+
12
+ This function generates an interactive Altair visualization showing the distribution
13
+ of positive and negative responses across different clusters for each question.
14
+ The visualization consists of two parts:
15
+ 1. A bar chart showing the number of respondents in each cluster
16
+ 2. A heatmap showing the sentiment distribution for each question by cluster
17
+
18
+ Parameters
19
+ ----------
20
+ df : pd.DataFrame
21
+ The DataFrame containing the clustered data and encoded Likert responses.
22
+ Should include a cluster column and encoded Likert columns.
23
+
24
+ x : str
25
+ The name of the column containing cluster IDs (e.g., 'question_cluster_id').
26
+
27
+ y : List[str]
28
+ List of column names containing the encoded Likert responses.
29
+ These should typically be columns with values -1, 0, 1 representing
30
+ negative, neutral, and positive responses.
31
+
32
+ max_width : int, default=75
33
+ Maximum width for wrapping question labels in the visualization.
34
+
35
+ Returns
36
+ -------
37
+ alt.VConcatChart
38
+ An Altair chart object combining a bar chart of cluster sizes and
39
+ a heatmap of sentiment distribution that can be displayed in a Jupyter notebook
40
+ or exported as HTML.
41
+
42
+ Notes
43
+ -----
44
+ The function color-codes the heatmap cells based on the percentage of
45
+ positive and negative responses, with green representing positive sentiment,
46
+ red representing negative sentiment, and varying shades for mixed responses.
47
+
48
+ The encoded Likert columns (y parameter) should contain values that are encoded as:
49
+ * 1 for positive responses
50
+ * 0 for neutral responses
51
+ * -1 for negative responses
52
+
53
+ Examples
54
+ --------
55
+ >>> # Assuming df has been processed with cluster_questions
56
+ >>> likert_columns = [f"likert_encoded_{q}" for q in questions]
57
+ >>> heatmap = cluster_heatmap_plot(df, x="question_cluster_id", y=likert_columns)
58
+ >>> display(heatmap)
59
+ """
60
+ # Convert -1, 0, 1 to percent positive and percent negative
61
+ df_positive = df[y].apply(lambda col: (col == 1).astype(int))
62
+ df_negative = df[y].apply(lambda col: (col == -1).astype(int))
63
+
64
+ # Calculate average percent positive and negative for each cluster and question
65
+ heatmap_data_pos = (
66
+ df_positive.groupby(df[x])
67
+ .mean()
68
+ .reset_index()
69
+ .melt(id_vars=x, var_name="question", value_name="percent_positive")
70
+ )
71
+ heatmap_data_neg = (
72
+ df_negative.groupby(df[x])
73
+ .mean()
74
+ .reset_index()
75
+ .melt(id_vars=x, var_name="question", value_name="percent_negative")
76
+ )
77
+
78
+ # Merge positive and negative data
79
+ heatmap_data = pd.merge(heatmap_data_pos, heatmap_data_neg, on=[x, "question"])
80
+ heatmap_data["percent_neutral"] = (
81
+ 1 - heatmap_data["percent_positive"] - heatmap_data["percent_negative"]
82
+ )
83
+
84
+ # Calculate overall positivity for each cluster
85
+ cluster_positivity = (
86
+ heatmap_data.groupby(x)["percent_positive"].mean().sort_values(ascending=False)
87
+ )
88
+ cluster_order = cluster_positivity.index.tolist()
89
+
90
+ # Replace underscores with spaces in question labels
91
+ heatmap_data["question"] = (
92
+ heatmap_data["question"].str.replace("_", " ").str.replace("likert encoded", "")
93
+ )
94
+
95
+ # Wrap long question labels
96
+ wrapped_labels = [
97
+ textwrap.fill(label, width=max_width)
98
+ for label in heatmap_data["question"].unique()
99
+ ]
100
+ label_to_wrapped = dict(zip(heatmap_data["question"].unique(), wrapped_labels))
101
+ heatmap_data["wrapped_question"] = heatmap_data["question"].map(label_to_wrapped)
102
+
103
+ # Define color scale based on percent positive and percent negative
104
+ def get_color(pos: float, neg: float) -> Tuple[str, str]:
105
+ if pos > 0.6:
106
+ return "#1a9641", "white" # Strong positive (green)
107
+ elif pos > 0.4:
108
+ return "#a6d96a", "black" # Moderate positive (light green)
109
+ elif pos > neg:
110
+ return "#ffffbf", "black" # Slightly positive (light yellow)
111
+ elif neg > 0.6:
112
+ return "#d7191c", "white" # Strong negative (red)
113
+ elif neg > 0.4:
114
+ return "#fdae61", "black" # Moderate negative (orange)
115
+ elif neg > pos:
116
+ return "#f4a582", "black" # Slightly negative (light red)
117
+ else:
118
+ return "#f7f7f7", "black" # Neutral (light gray)
119
+
120
+ heatmap_data["background_color"], heatmap_data["text_color"] = zip(
121
+ *heatmap_data.apply(
122
+ lambda row: get_color(row["percent_positive"], row["percent_negative"]),
123
+ axis=1,
124
+ )
125
+ )
126
+
127
+ # Calculate chart dimensions
128
+ chart_width = 600
129
+ row_height = 30
130
+ heatmap_height = len(wrapped_labels) * row_height
131
+ bar_chart_height = 100
132
+
133
+ # Create heatmap
134
+ heatmap = (
135
+ alt.Chart(heatmap_data)
136
+ .mark_rect()
137
+ .encode(
138
+ x=alt.X(f"{x}:O", title="Cluster ID", sort=cluster_order),
139
+ y=alt.Y("wrapped_question:O", title=None, sort=wrapped_labels),
140
+ color=alt.Color("background_color:N", scale=None),
141
+ tooltip=[
142
+ alt.Tooltip(f"{x}:O", title="Cluster ID"),
143
+ alt.Tooltip("question:O", title="Question"),
144
+ alt.Tooltip("percent_positive:Q", title="% Positive", format=".2%"),
145
+ alt.Tooltip("percent_negative:Q", title="% Negative", format=".2%"),
146
+ alt.Tooltip("percent_neutral:Q", title="% Neutral", format=".2%"),
147
+ ],
148
+ )
149
+ .properties(
150
+ width=chart_width,
151
+ height=heatmap_height,
152
+ title="Cluster Heatmap: Sentiment Distribution",
153
+ )
154
+ )
155
+
156
+ # Add text labels to heatmap
157
+ text = heatmap.mark_text(baseline="middle").encode(
158
+ text=alt.Text("percent_positive:Q", format=".0%"),
159
+ color=alt.Color("text_color:N", scale=None),
160
+ )
161
+
162
+ # Create bar chart for cluster counts
163
+ cluster_counts = df[x].value_counts().reset_index()
164
+ cluster_counts.columns = [x, "count"]
165
+ cluster_counts[x] = pd.Categorical(
166
+ cluster_counts[x], categories=cluster_order, ordered=True
167
+ )
168
+ cluster_counts = cluster_counts.sort_values(x)
169
+
170
+ bar_chart = (
171
+ alt.Chart(cluster_counts)
172
+ .mark_bar()
173
+ .encode(
174
+ x=alt.X(f"{x}:O", title="Cluster ID", sort=cluster_order),
175
+ y=alt.Y("count:Q", title="Count"),
176
+ tooltip=[
177
+ alt.Tooltip(f"{x}:O", title="Cluster ID"),
178
+ alt.Tooltip("count:Q", title="Count"),
179
+ ],
180
+ )
181
+ .properties(width=chart_width, height=bar_chart_height, title="Cluster Sizes")
182
+ )
183
+
184
+ # Add text labels to bar chart
185
+ bar_text = bar_chart.mark_text(align="center", baseline="bottom", dy=-5).encode(
186
+ text="count:Q"
187
+ )
188
+
189
+ # Combine bar chart and heatmap using vconcat
190
+ combined_chart = (
191
+ alt.vconcat((bar_chart + bar_text), (heatmap + text))
192
+ .configure_view(strokeWidth=0)
193
+ .configure_axis(
194
+ labelLimit=350 # Increase label limit to show full wrapped text
195
+ )
196
+ )
197
+
198
+ return combined_chart