broadlistening 0.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.rspec +3 -0
- data/.rubocop.yml +3 -0
- data/CHANGELOG.md +40 -0
- data/CLAUDE.md +112 -0
- data/LICENSE +24 -0
- data/LICENSE-AGPLv3.txt +661 -0
- data/README.md +195 -0
- data/Rakefile +77 -0
- data/exe/broadlistening +6 -0
- data/lib/broadlistening/argument.rb +136 -0
- data/lib/broadlistening/cli.rb +196 -0
- data/lib/broadlistening/comment.rb +128 -0
- data/lib/broadlistening/compatibility.rb +375 -0
- data/lib/broadlistening/config.rb +190 -0
- data/lib/broadlistening/context.rb +180 -0
- data/lib/broadlistening/csv_loader.rb +109 -0
- data/lib/broadlistening/hierarchical_clustering.rb +142 -0
- data/lib/broadlistening/kmeans.rb +185 -0
- data/lib/broadlistening/llm_client.rb +84 -0
- data/lib/broadlistening/pipeline.rb +129 -0
- data/lib/broadlistening/planner.rb +114 -0
- data/lib/broadlistening/provider.rb +97 -0
- data/lib/broadlistening/spec_loader.rb +86 -0
- data/lib/broadlistening/status.rb +132 -0
- data/lib/broadlistening/steps/aggregation.rb +228 -0
- data/lib/broadlistening/steps/base_step.rb +42 -0
- data/lib/broadlistening/steps/clustering.rb +103 -0
- data/lib/broadlistening/steps/embedding.rb +40 -0
- data/lib/broadlistening/steps/extraction.rb +73 -0
- data/lib/broadlistening/steps/initial_labelling.rb +85 -0
- data/lib/broadlistening/steps/merge_labelling.rb +93 -0
- data/lib/broadlistening/steps/overview.rb +36 -0
- data/lib/broadlistening/version.rb +5 -0
- data/lib/broadlistening.rb +44 -0
- data/schema/hierarchical_result.json +152 -0
- data/sig/broadlistening.rbs +4 -0
- metadata +194 -0
data/lib/broadlistening/context.rb
@@ -0,0 +1,180 @@
+# frozen_string_literal: true
+
+require "pathname"
+require "fileutils"
+require "json"
+
+module Broadlistening
+  # Manages pipeline execution context - all data flowing through the pipeline.
+  #
+  # The Context holds all intermediate results and provides methods for
+  # loading from / saving to disk for incremental execution support.
+  #
+  # @example Creating a new context
+  #   context = Context.new
+  #   context.comments = [Comment.new(...), ...]
+  #   context.save_step(:extraction, output_dir)
+  #
+  # @example Loading from existing output
+  #   context = Context.load_from_dir("/path/to/output")
+  class Context
+    attr_accessor :comments, :arguments, :relations,
+                  :cluster_results, :umap_coords,
+                  :initial_labels, :labels, :overview, :result,
+                  :output_dir
+
+    # Output file mapping for each step
+    OUTPUT_FILES = {
+      extraction: "extraction.json",
+      embedding: "embeddings.json",
+      clustering: "clustering.json",
+      initial_labelling: "initial_labels.json",
+      merge_labelling: "merge_labels.json",
+      overview: "overview.json",
+      aggregation: "result.json"
+    }.freeze
+
+    # Load existing context from output directory
+    #
+    # @param output_dir [String, Pathname] Directory containing output files
+    # @return [Context] A new context populated with data from output files
+    def self.load_from_dir(output_dir)
+      context = new
+      dir = Pathname.new(output_dir)
+
+      OUTPUT_FILES.each do |step, filename|
+        file = dir / filename
+        next unless file.exist?
+
+        data = JSON.parse(file.read, symbolize_names: true)
+        context.send(:merge_step_data, step, data)
+      end
+
+      context
+    end
+
+    def initialize
+      @comments = []
+      @arguments = []
+      @relations = []
+      @cluster_results = {}
+      @umap_coords = nil
+      @initial_labels = {}
+      @labels = {}
+      @overview = nil
+      @result = nil
+      @output_dir = nil
+    end
+
+    # Save a step's output to file
+    #
+    # @param step_name [Symbol] The step name
+    # @param output_dir [String, Pathname] Output directory
+    def save_step(step_name, output_dir)
+      dir = Pathname.new(output_dir)
+      filename = OUTPUT_FILES[step_name]
+      return unless filename
+
+      FileUtils.mkdir_p(dir)
+      data = extract_step_output(step_name)
+      File.write(dir / filename, JSON.pretty_generate(data))
+    end
+
+    # Convert to hash for serialization
+    #
+    # @return [Hash]
+    def to_h
+      {
+        comments: @comments.map(&:to_h),
+        arguments: @arguments.map(&:to_h),
+        relations: @relations,
+        cluster_results: @cluster_results,
+        umap_coords: @umap_coords,
+        initial_labels: @initial_labels,
+        labels: @labels,
+        overview: @overview,
+        result: @result
+      }
+    end
+
+    private
+
+    def extract_step_output(step_name)
+      case step_name
+      when :extraction
+        {
+          comments: @comments.map(&:to_h),
+          arguments: @arguments.map(&:to_h),
+          relations: @relations
+        }
+      when :embedding
+        {
+          arguments: @arguments.map(&:to_embedding_h)
+        }
+      when :clustering
+        {
+          cluster_results: @cluster_results,
+          arguments: @arguments.map(&:to_clustering_h)
+        }
+      when :initial_labelling
+        { initial_labels: @initial_labels }
+      when :merge_labelling
+        { labels: @labels }
+      when :overview
+        { overview: @overview }
+      when :aggregation
+        @result
+      end
+    end
+
+    def merge_step_data(step_name, data)
+      case step_name
+      when :extraction
+        load_extraction_data(data)
+      when :embedding
+        merge_embedding_data(data)
+      when :clustering
+        merge_clustering_data(data)
+      when :initial_labelling
+        @initial_labels = data[:initial_labels] if data[:initial_labels]
+      when :merge_labelling
+        @labels = data[:labels] if data[:labels]
+      when :overview
+        @overview = data[:overview] if data[:overview]
+      end
+    end
+
+    def load_extraction_data(data)
+      @comments = (data[:comments] || []).map do |c|
+        Comment.new(**c.slice(:id, :body, :proposal_id, :source_url, :attributes, :properties))
+      end
+      @arguments = (data[:arguments] || []).map { |a| Argument.from_hash(a) }
+      @relations = data[:relations] if data[:relations]
+    end
+
+    def merge_embedding_data(data)
+      return unless data[:arguments]
+
+      embedding_map = data[:arguments].to_h { |e| [ e[:arg_id], e[:embedding] ] }
+      @arguments.each do |arg|
+        embedding = embedding_map[arg.arg_id]
+        arg.embedding = embedding if embedding
+      end
+    end
+
+    def merge_clustering_data(data)
+      @cluster_results = data[:cluster_results] if data[:cluster_results]
+      return unless data[:arguments]
+
+      clustering_map = data[:arguments].to_h { |c| [ c[:arg_id], c ] }
+      @arguments.each do |arg|
+        cluster_data = clustering_map[arg.arg_id]
+        next unless cluster_data
+
+        arg.x = cluster_data[:x]
+        arg.y = cluster_data[:y]
+        arg.cluster_ids = cluster_data[:cluster_ids]
+      end
+    end
+  end
+end
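Note: a minimal usage sketch of the Context API above, resuming a pipeline from previously saved step files and persisting a later step. The "out" directory and the presence of extraction.json / embeddings.json in it are illustrative assumptions, not documented fixtures of the gem.

    require "broadlistening"

    # Rebuild a context from whatever step outputs already exist on disk;
    # load_from_dir walks OUTPUT_FILES and merges each file it finds.
    context = Broadlistening::Context.load_from_dir("out")
    context.comments.size              # comments restored from extraction.json
    context.arguments.first&.embedding # embeddings merged back onto arguments

    # After re-running a later step, persist just that step's slice:
    context.save_step(:clustering, "out") # writes out/clustering.json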
data/lib/broadlistening/csv_loader.rb
@@ -0,0 +1,109 @@
+# frozen_string_literal: true
+
+require "csv"
+
+module Broadlistening
+  # Loads comments from CSV files with Kouchou-AI compatible format.
+  #
+  # This loader supports the CSV format used by Kouchou-AI (Python version),
+  # enabling compatibility testing between the two implementations.
+  #
+  # @example Loading a Kouchou-AI format CSV
+  #   comments = CsvLoader.load("inputs/example-polis.csv")
+  #
+  # @example Loading with property columns
+  #   comments = CsvLoader.load("data.csv", property_names: ["agrees", "disagrees"])
+  #
+  # @example Using custom column mapping
+  #   comments = CsvLoader.load("custom.csv", column_mapping: {
+  #     id: "my_id_column",
+  #     body: "my_body_column"
+  #   })
+  class CsvLoader
+    # Default column mapping for Kouchou-AI format
+    # Maps Ruby gem's expected keys to Kouchou-AI's CSV column names
+    KOUCHOU_AI_COLUMNS = {
+      id: "comment-id",
+      body: "comment-body",
+      source_url: "source-url"
+    }.freeze
+
+    class << self
+      # Load comments from a CSV file
+      #
+      # @param path [String] Path to the CSV file
+      # @param property_names [Array<String>] Property column names to extract
+      # @param column_mapping [Hash] Custom column name mapping (overrides defaults)
+      # @param encoding [String] File encoding (default: UTF-8 with BOM handling)
+      # @return [Array<Comment>] Array of Comment objects
+      def load(path, property_names: [], column_mapping: {}, encoding: "bom|utf-8")
+        mapping = KOUCHOU_AI_COLUMNS.merge(column_mapping)
+
+        comments = []
+        CSV.foreach(path, headers: true, encoding: encoding) do |row|
+          comment = build_comment(row, mapping, property_names)
+          comments << comment unless comment.nil?
+        end
+        comments
+      end
+
+      # Load comments from a CSV string
+      #
+      # @param csv_string [String] CSV content as string
+      # @param property_names [Array<String>] Property column names to extract
+      # @param column_mapping [Hash] Custom column name mapping
+      # @return [Array<Comment>] Array of Comment objects
+      def parse(csv_string, property_names: [], column_mapping: {})
+        mapping = KOUCHOU_AI_COLUMNS.merge(column_mapping)
+
+        comments = []
+        CSV.parse(csv_string, headers: true) do |row|
+          comment = build_comment(row, mapping, property_names)
+          comments << comment unless comment.nil?
+        end
+        comments
+      end
+
+      private
+
+      def build_comment(row, mapping, property_names)
+        id = extract_value(row, mapping[:id], "id")
+        body = extract_value(row, mapping[:body], "body")
+
+        # Skip rows without required fields
+        return nil if id.nil? || body.nil? || body.strip.empty?
+
+        hash = {
+          id: id.to_s,
+          body: body,
+          source_url: extract_value(row, mapping[:source_url], "source_url", "source-url")
+        }
+
+        # Extract attribute columns (attribute_* or attribute-*)
+        row.headers.each do |header|
+          next if header.nil?
+
+          if header.start_with?("attribute_") || header.start_with?("attribute-")
+            hash[header.to_sym] = row[header]
+          end
+        end
+
+        # Extract property columns
+        property_names.each do |prop_name|
+          value = row[prop_name] || row[prop_name.to_s.tr("_", "-")]
+          hash[prop_name.to_sym] = value if value
+        end
+
+        Comment.from_hash(hash, property_names: property_names)
+      end
+
+      def extract_value(row, *possible_names)
+        possible_names.compact.each do |name|
+          value = row[name]
+          return value unless value.nil?
+        end
+        nil
+      end
+    end
+  end
+end
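Note: a sketch of the Kouchou-AI CSV shape that CsvLoader expects; the rows and column values below are made up for illustration.

    require "broadlistening"

    # comment-id and comment-body are required; attribute-* columns and any
    # columns named in property_names are carried onto each Comment.
    csv = <<~CSV
      comment-id,comment-body,source-url,attribute-region,agrees
      1,"Please add more bike lanes",https://example.com/1,north,12
      2,"",https://example.com/2,south,3
      3,"The park needs better lighting",,south,5
    CSV

    comments = Broadlistening::CsvLoader.parse(csv, property_names: ["agrees"])
    comments.size # => 2 - the row with an empty body is skipped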
data/lib/broadlistening/hierarchical_clustering.rb
@@ -0,0 +1,142 @@
+# frozen_string_literal: true
+
+module Broadlistening
+  class HierarchicalClustering
+    # Hierarchical clustering with Ward's method.
+    # An implementation equivalent to scipy.cluster.hierarchy.linkage(method="ward").
+
+    class << self
+      def merge(centroids, labels, target_clusters)
+        new(centroids, labels, target_clusters).merge
+      end
+    end
+
+    def initialize(centroids, labels, target_clusters)
+      @centroids = to_numo_array(centroids)
+      @labels = labels.dup
+      @target_clusters = target_clusters
+      @n_original_clusters = @centroids.shape[0]
+    end
+
+    def merge
+      return @labels if current_cluster_count <= @target_clusters
+
+      # Initialize cluster bookkeeping.
+      # Each cluster: { centroid: centroid, size: size, members: original cluster IDs }
+      clusters = initialize_clusters
+
+      # Merge hierarchically using Ward's method
+      while clusters.size > @target_clusters
+        c1_id, c2_id = find_ward_closest_pair(clusters)
+        break if c1_id.nil?
+
+        merge_ward_clusters!(clusters, c1_id, c2_id)
+      end
+
+      # Map the original labels to the new cluster IDs
+      build_final_labels(clusters)
+    end
+
+    private
+
+    def to_numo_array(centroids)
+      if centroids.is_a?(Numo::DFloat)
+        centroids
+      else
+        Numo::DFloat.cast(centroids)
+      end
+    end
+
+    def current_cluster_count
+      @labels.uniq.size
+    end
+
+    def initialize_clusters
+      clusters = {}
+      @n_original_clusters.times do |i|
+        clusters[i] = {
+          centroid: @centroids[i, true].to_a,
+          size: @labels.count(i),
+          members: [ i ]
+        }
+      end
+      clusters
+    end
+
+    def find_ward_closest_pair(clusters)
+      min_dist = Float::INFINITY
+      min_pair = [ nil, nil ]
+
+      cluster_ids = clusters.keys
+      cluster_ids.each_with_index do |c1_id, i|
+        cluster_ids[(i + 1)..].each do |c2_id|
+          dist = ward_distance(clusters[c1_id], clusters[c2_id])
+          if dist < min_dist
+            min_dist = dist
+            min_pair = [ c1_id, c2_id ]
+          end
+        end
+      end
+
+      min_pair
+    end
+
+    def ward_distance(cluster1, cluster2)
+      # Ward's method: compute the increase in variance caused by merging
+      # d(i,j) = sqrt(2 * n_i * n_j / (n_i + n_j)) * ||c_i - c_j||
+      n1 = cluster1[:size]
+      n2 = cluster2[:size]
+      c1 = cluster1[:centroid]
+      c2 = cluster2[:centroid]
+
+      # Squared Euclidean distance
+      dist_sq = c1.zip(c2).sum { |a, b| (a - b)**2 }
+
+      # Ward distance
+      Math.sqrt(2.0 * n1 * n2 / (n1 + n2) * dist_sq)
+    end
+
+    def merge_ward_clusters!(clusters, c1_id, c2_id)
+      c1 = clusters[c1_id]
+      c2 = clusters[c2_id]
+
+      # Compute the new centroid (weighted by cluster size)
+      n1 = c1[:size]
+      n2 = c2[:size]
+      new_size = n1 + n2
+
+      new_centroid = c1[:centroid].zip(c2[:centroid]).map do |v1, v2|
+        (v1 * n1 + v2 * n2) / new_size
+      end
+
+      # Create the merged cluster (reuse the smaller ID)
+      merged_id = [ c1_id, c2_id ].min
+      removed_id = [ c1_id, c2_id ].max
+
+      clusters[merged_id] = {
+        centroid: new_centroid,
+        size: new_size,
+        members: c1[:members] + c2[:members]
+      }
+
+      clusters.delete(removed_id)
+    end
+
+    def build_final_labels(clusters)
+      # Build a mapping from each original cluster ID to its final cluster ID
+      original_to_final = {}
+      clusters.each_value do |cluster|
+        final_id = cluster[:members].min
+        cluster[:members].each do |original_id|
+          original_to_final[original_id] = final_id
+        end
+      end
+
+      # Renumber the final IDs to consecutive integers
+      unique_finals = original_to_final.values.uniq.sort
+      final_remap = unique_finals.each_with_index.to_h
+
+      @labels.map { |l| final_remap[original_to_final[l]] }
+    end
+  end
+end
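Note: a toy run of the Ward merger above, collapsing four fine-grained clusters into two; the centroids and labels are invented values for illustration.

    require "broadlistening"

    centroids = [
      [0.0, 0.0], [0.1, 0.0], # two clusters near the origin
      [5.0, 5.0], [5.1, 5.0]  # two clusters far away
    ]
    labels = [0, 0, 1, 2, 2, 3] # per-point assignments to the four clusters

    merged = Broadlistening::HierarchicalClustering.merge(centroids, labels, 2)
    # => [0, 0, 0, 1, 1, 1] - the two nearby pairs have the smallest Ward
    # distance and merge first; build_final_labels then renumbers the
    # surviving cluster IDs consecutively.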
data/lib/broadlistening/kmeans.rb
@@ -0,0 +1,185 @@
+# frozen_string_literal: true
+
+module Broadlistening
+  class KMeans
+    attr_reader :centroids, :labels, :n_clusters, :inertia
+
+    DEFAULT_MAX_ITERATIONS = 100
+    DEFAULT_TOLERANCE = 1e-6
+
+    def initialize(n_clusters:, max_iterations: DEFAULT_MAX_ITERATIONS, random_state: nil, tolerance: DEFAULT_TOLERANCE)
+      @n_clusters = n_clusters
+      @max_iterations = max_iterations
+      @tolerance = tolerance
+      @random = random_state ? Random.new(random_state) : Random.new
+      @centroids = nil
+      @labels = nil
+      @inertia = nil
+    end
+
+    def fit(data)
+      @data = to_numo_array(data)
+      validate_data!
+
+      @centroids = initialize_centroids_pp
+      @labels = Array.new(@data.shape[0])
+
+      @max_iterations.times do
+        @labels = assign_labels
+        new_centroids = update_centroids
+
+        if converged?(new_centroids)
+          @centroids = new_centroids
+          break
+        end
+
+        @centroids = new_centroids
+      end
+
+      @inertia = compute_inertia
+      self
+    end
+
+    def predict(data)
+      data = to_numo_array(data)
+      assign_labels_for(data)
+    end
+
+    def fit_predict(data)
+      fit(data)
+      @labels
+    end
+
+    private
+
+    def to_numo_array(data)
+      return data if data.is_a?(Numo::DFloat)
+
+      Numo::DFloat.cast(data)
+    end
+
+    def validate_data!
+      n_samples = @data.shape[0]
+      raise ClusteringError, "n_clusters (#{@n_clusters}) must be <= n_samples (#{n_samples})" if @n_clusters > n_samples
+      raise ClusteringError, "n_clusters must be positive" if @n_clusters <= 0
+    end
+
+    def initialize_centroids_pp
+      n_samples = @data.shape[0]
+      n_features = @data.shape[1]
+      centroids = Numo::DFloat.zeros(@n_clusters, n_features)
+
+      first_idx = @random.rand(n_samples)
+      centroids[0, true] = @data[first_idx, true]
+
+      (1...@n_clusters).each do |k|
+        distances = compute_min_distances_to_centroids(centroids[0...k, true])
+        probabilities = distances**2
+        sum_probs = probabilities.sum
+        probabilities /= sum_probs if sum_probs > 0
+
+        next_idx = weighted_random_choice(probabilities)
+        centroids[k, true] = @data[next_idx, true]
+      end
+
+      centroids
+    end
+
+    def compute_min_distances_to_centroids(centroids)
+      n_samples = @data.shape[0]
+      min_distances = Numo::DFloat.new(n_samples).fill(Float::INFINITY)
+
+      centroids.shape[0].times do |k|
+        distances = compute_distances_to_centroid(@data, centroids[k, true])
+        min_distances = Numo::DFloat.minimum(min_distances, distances)
+      end
+
+      min_distances
+    end
+
+    def compute_distances_to_centroid(points, centroid)
+      diff = points - centroid
+      (diff**2).sum(axis: 1)
+    end
+
+    def weighted_random_choice(probabilities)
+      cumsum = 0.0
+      threshold = @random.rand
+      probs_array = probabilities.to_a
+
+      probs_array.each_with_index do |prob, idx|
+        cumsum += prob
+        return idx if cumsum >= threshold
+      end
+
+      probs_array.size - 1
+    end
+
+    def assign_labels
+      assign_labels_for(@data)
+    end
+
+    def assign_labels_for(data)
+      n_samples = data.shape[0]
+      labels = Array.new(n_samples)
+
+      n_samples.times do |i|
+        point = data[i, true]
+        min_dist = Float::INFINITY
+        min_label = 0
+
+        @n_clusters.times do |k|
+          dist = squared_distance(point, @centroids[k, true])
+          if dist < min_dist
+            min_dist = dist
+            min_label = k
+          end
+        end
+
+        labels[i] = min_label
+      end
+
+      labels
+    end
+
+    def squared_distance(a, b)
+      ((a - b)**2).sum
+    end
+
+    def update_centroids
+      n_features = @data.shape[1]
+      new_centroids = Numo::DFloat.zeros(@n_clusters, n_features)
+      counts = Array.new(@n_clusters, 0)
+
+      @data.shape[0].times do |i|
+        label = @labels[i]
+        new_centroids[label, true] += @data[i, true]
+        counts[label] += 1
+      end
+
+      @n_clusters.times do |k|
+        if counts[k] > 0
+          new_centroids[k, true] /= counts[k]
+        else
+          random_idx = @random.rand(@data.shape[0])
+          new_centroids[k, true] = @data[random_idx, true]
+        end
+      end
+
+      new_centroids
+    end
+
+    def converged?(new_centroids)
+      ((@centroids - new_centroids)**2).sum < @tolerance
+    end
+
+    def compute_inertia
+      total = 0.0
+      @data.shape[0].times do |i|
+        label = @labels[i]
+        total += squared_distance(@data[i, true], @centroids[label, true])
+      end
+      total
+    end
+  end
+end
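Note: a small end-to-end run of the KMeans class above on toy 2-D points. random_state pins the k-means++ seeding so the run is repeatable; numo/narray is the array backend the class operates on.

    require "broadlistening"
    require "numo/narray"

    data = Numo::DFloat[[0.0, 0.0], [0.2, 0.1], [4.0, 4.0], [4.1, 3.9]]

    km = Broadlistening::KMeans.new(n_clusters: 2, random_state: 42)
    labels = km.fit_predict(data) # => e.g. [0, 0, 1, 1]
    km.centroids                  # 2x2 Numo::DFloat of cluster centers
    km.inertia                    # total within-cluster squared distance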
data/lib/broadlistening/llm_client.rb
@@ -0,0 +1,84 @@
+# frozen_string_literal: true
+
+module Broadlistening
+  class LlmClient
+    MAX_RETRIES = 3
+    RETRY_DELAY = 1
+
+    def initialize(config)
+      @config = config
+      @provider = Provider.new(config.provider, local_llm_address: config.local_llm_address)
+      @client = @provider.build_openai_client(
+        api_key: config.api_key,
+        base_url: config.api_base_url,
+        azure_api_version: config.azure_api_version
+      )
+    end
+
+    def chat(system:, user:, json_mode: false)
+      params = build_chat_params(system, user, json_mode)
+      response = with_retry { @client.chat(parameters: params) }
+      extract_chat_content(response)
+    end
+
+    def embed(texts)
+      texts = [ texts ] if texts.is_a?(String)
+      response = with_retry do
+        @client.embeddings(
+          parameters: {
+            model: @config.embedding_model,
+            input: texts
+          }
+        )
+      end
+      extract_embeddings(response)
+    end
+
+    private
+
+    def build_chat_params(system, user, json_mode)
+      params = {
+        model: @config.model,
+        messages: [
+          { role: "system", content: system },
+          { role: "user", content: user }
+        ]
+      }
+      params[:response_format] = { type: "json_object" } if json_mode
+      params
+    end
+
+    def extract_chat_content(response)
+      validate_response!(response)
+      response.dig("choices", 0, "message", "content")
+    end
+
+    def extract_embeddings(response)
+      validate_response!(response)
+      response["data"].sort_by { |d| d["index"] }.map { |d| d["embedding"] }
+    end
+
+    def validate_response!(response)
+      return if response.is_a?(Hash) && !response["error"]
+
+      error_message = response.is_a?(Hash) ? response.dig("error", "message") : "Unknown error"
+      raise LlmError, "LLM API error: #{error_message}"
+    end
+
+    def with_retry
+      retries = 0
+      begin
+        yield
+      rescue Faraday::ClientError => e
+        raise LlmError, "LLM API error: #{e.message}"
+      rescue Faraday::ServerError, Faraday::ConnectionFailed, Faraday::TimeoutError,
+             Net::OpenTimeout, Errno::ECONNRESET => e
+        retries += 1
+        raise LlmError, "LLM API request failed after #{MAX_RETRIES} retries: #{e.message}" if retries > MAX_RETRIES
+
+        sleep(RETRY_DELAY * retries)
+        retry
+      end
+    end
+  end
+end
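Note: a sketch of driving LlmClient directly. The config object here is a stand-in: it is assumed to respond to the accessors the constructor reads (provider, api_key, model, embedding_model, and so on), which in the gem would come from Broadlistening::Config.

    require "broadlistening"

    client = Broadlistening::LlmClient.new(config) # config: see assumption above

    summary = client.chat(
      system: "You summarise public comments.",
      user: "Summarise: 'More bike lanes, please.'"
    )

    vectors = client.embed(["More bike lanes, please."]) # => [[0.01, -0.02, ...]]
    # Timeouts, connection resets, and 5xx responses are retried up to
    # MAX_RETRIES times with a linearly growing delay; 4xx client errors
    # raise LlmError immediately.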