mseep_rmcp-0.3.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mseep_rmcp-0.3.3.dist-info/METADATA +50 -0
- mseep_rmcp-0.3.3.dist-info/RECORD +34 -0
- mseep_rmcp-0.3.3.dist-info/WHEEL +5 -0
- mseep_rmcp-0.3.3.dist-info/entry_points.txt +2 -0
- mseep_rmcp-0.3.3.dist-info/licenses/LICENSE +21 -0
- mseep_rmcp-0.3.3.dist-info/top_level.txt +1 -0
- rmcp/__init__.py +31 -0
- rmcp/cli.py +317 -0
- rmcp/core/__init__.py +14 -0
- rmcp/core/context.py +150 -0
- rmcp/core/schemas.py +156 -0
- rmcp/core/server.py +261 -0
- rmcp/r_assets/__init__.py +8 -0
- rmcp/r_integration.py +112 -0
- rmcp/registries/__init__.py +26 -0
- rmcp/registries/prompts.py +316 -0
- rmcp/registries/resources.py +266 -0
- rmcp/registries/tools.py +223 -0
- rmcp/scripts/__init__.py +9 -0
- rmcp/security/__init__.py +15 -0
- rmcp/security/vfs.py +233 -0
- rmcp/tools/descriptive.py +279 -0
- rmcp/tools/econometrics.py +250 -0
- rmcp/tools/fileops.py +315 -0
- rmcp/tools/machine_learning.py +299 -0
- rmcp/tools/regression.py +287 -0
- rmcp/tools/statistical_tests.py +332 -0
- rmcp/tools/timeseries.py +239 -0
- rmcp/tools/transforms.py +293 -0
- rmcp/tools/visualization.py +590 -0
- rmcp/transport/__init__.py +16 -0
- rmcp/transport/base.py +130 -0
- rmcp/transport/jsonrpc.py +243 -0
- rmcp/transport/stdio.py +201 -0
rmcp/tools/fileops.py
ADDED
@@ -0,0 +1,315 @@
"""
File operations tools for RMCP.

Data import, export, and file manipulation capabilities.
"""

from typing import Dict, Any
from ..registries.tools import tool
from ..core.schemas import table_schema
from ..r_integration import execute_r_script


@tool(
    name="read_csv",
    input_schema={
        "type": "object",
        "properties": {
            "file_path": {"type": "string"},
            "header": {"type": "boolean", "default": True},
            "sep": {"type": "string", "default": ","},
            "na_strings": {"type": "array", "items": {"type": "string"}, "default": ["", "NA", "NULL"]},
            "skip_rows": {"type": "integer", "minimum": 0, "default": 0},
            "max_rows": {"type": "integer", "minimum": 1}
        },
        "required": ["file_path"]
    },
    description="Read CSV files with flexible parsing options"
)
async def read_csv(context, params):
    """Read CSV file and return data."""

    await context.info("Reading CSV file", file_path=params.get("file_path"))

    r_script = '''
    file_path <- args$file_path
    header <- args$header %||% TRUE
    sep <- args$sep %||% ","
    na_strings <- args$na_strings %||% c("", "NA", "NULL")
    skip_rows <- args$skip_rows %||% 0
    max_rows <- args$max_rows

    # Check if file exists
    if (!file.exists(file_path)) {
      stop(paste("File not found:", file_path))
    }

    # Read CSV
    if (!is.null(max_rows)) {
      data <- read.csv(file_path, header = header, sep = sep,
                       na.strings = na_strings, skip = skip_rows, nrows = max_rows)
    } else {
      data <- read.csv(file_path, header = header, sep = sep,
                       na.strings = na_strings, skip = skip_rows)
    }

    # Data summary
    numeric_vars <- names(data)[sapply(data, is.numeric)]
    character_vars <- names(data)[sapply(data, is.character)]
    factor_vars <- names(data)[sapply(data, is.factor)]

    result <- list(
      data = data,
      file_info = list(
        file_path = file_path,
        n_rows = nrow(data),
        n_cols = ncol(data),
        column_names = names(data),
        numeric_variables = numeric_vars,
        character_variables = character_vars,
        factor_variables = factor_vars
      ),
      parsing_info = list(
        header = header,
        separator = sep,
        na_strings = na_strings,
        rows_skipped = skip_rows
      )
    )
    '''

    try:
        result = execute_r_script(r_script, params)
        await context.info("CSV file read successfully",
                           rows=result["file_info"]["n_rows"],
                           cols=result["file_info"]["n_cols"])
        return result

    except Exception as e:
        await context.error("CSV reading failed", error=str(e))
        raise


@tool(
    name="write_csv",
    input_schema={
        "type": "object",
        "properties": {
            "data": table_schema(),
            "file_path": {"type": "string"},
            "include_rownames": {"type": "boolean", "default": False},
            "na_string": {"type": "string", "default": ""},
            "append": {"type": "boolean", "default": False}
        },
        "required": ["data", "file_path"]
    },
    description="Write data to CSV file with formatting options"
)
async def write_csv(context, params):
    """Write data to CSV file."""

    await context.info("Writing CSV file", file_path=params.get("file_path"))

    r_script = '''
    data <- as.data.frame(args$data)
    file_path <- args$file_path
    include_rownames <- args$include_rownames %||% FALSE
    na_string <- args$na_string %||% ""
    append_mode <- args$append %||% FALSE

    # Write CSV
    write.csv(data, file_path, row.names = include_rownames, na = na_string, append = append_mode)

    # Verify file was written
    if (!file.exists(file_path)) {
      stop(paste("Failed to write file:", file_path))
    }

    file_info <- file.info(file_path)

    result <- list(
      file_path = file_path,
      rows_written = nrow(data),
      cols_written = ncol(data),
      file_size_bytes = file_info$size,
      success = TRUE,
      timestamp = as.character(Sys.time())
    )
    '''

    try:
        result = execute_r_script(r_script, params)
        await context.info("CSV file written successfully")
        return result

    except Exception as e:
        await context.error("CSV writing failed", error=str(e))
        raise


@tool(
    name="data_info",
    input_schema={
        "type": "object",
        "properties": {
            "data": table_schema(),
            "include_sample": {"type": "boolean", "default": True},
            "sample_size": {"type": "integer", "minimum": 1, "maximum": 100, "default": 5}
        },
        "required": ["data"]
    },
    description="Get comprehensive information about a dataset"
)
async def data_info(context, params):
    """Get comprehensive dataset information."""

    await context.info("Analyzing dataset structure")

    r_script = '''
    data <- as.data.frame(args$data)
    include_sample <- args$include_sample %||% TRUE
    sample_size <- args$sample_size %||% 5

    # Basic info
    n_rows <- nrow(data)
    n_cols <- ncol(data)
    col_names <- names(data)

    # Variable types
    var_types <- sapply(data, class)
    numeric_vars <- names(data)[sapply(data, is.numeric)]
    character_vars <- names(data)[sapply(data, is.character)]
    factor_vars <- names(data)[sapply(data, is.factor)]
    logical_vars <- names(data)[sapply(data, is.logical)]
    date_vars <- names(data)[sapply(data, function(x) inherits(x, "Date"))]

    # Missing value analysis
    missing_counts <- sapply(data, function(x) sum(is.na(x)))
    missing_percentages <- missing_counts / n_rows * 100

    # Memory usage
    memory_usage <- object.size(data)

    result <- list(
      dimensions = list(rows = n_rows, columns = n_cols),
      variables = list(
        all = col_names,
        numeric = numeric_vars,
        character = character_vars,
        factor = factor_vars,
        logical = logical_vars,
        date = date_vars
      ),
      variable_types = as.list(var_types),
      missing_values = list(
        counts = as.list(missing_counts),
        percentages = as.list(missing_percentages),
        total_missing = sum(missing_counts),
        complete_cases = sum(complete.cases(data))
      ),
      memory_usage_bytes = as.numeric(memory_usage)
    )

    # Add data sample if requested
    if (include_sample && n_rows > 0) {
      sample_rows <- min(sample_size, n_rows)
      result$sample_data <- head(data, sample_rows)
    }
    '''

    try:
        result = execute_r_script(r_script, params)
        await context.info("Dataset analysis completed successfully")
        return result

    except Exception as e:
        await context.error("Dataset analysis failed", error=str(e))
        raise


@tool(
    name="filter_data",
    input_schema={
        "type": "object",
        "properties": {
            "data": table_schema(),
            "conditions": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "variable": {"type": "string"},
                        "operator": {"type": "string", "enum": ["==", "!=", ">", "<", ">=", "<=", "%in%", "!%in%"]},
                        "value": {}
                    },
                    "required": ["variable", "operator", "value"]
                }
            },
            "logic": {"type": "string", "enum": ["AND", "OR"], "default": "AND"}
        },
        "required": ["data", "conditions"]
    },
    description="Filter data based on multiple conditions"
)
async def filter_data(context, params):
    """Filter data based on conditions."""

    await context.info("Filtering data")

    r_script = '''
    if (!require(dplyr)) install.packages("dplyr", quietly = TRUE)
    library(dplyr)

    data <- as.data.frame(args$data)
    conditions <- args$conditions
    logic <- args$logic %||% "AND"

    # Build filter expressions
    filter_expressions <- c()

    for (condition in conditions) {
      var <- condition$variable
      op <- condition$operator
      val <- condition$value

      if (op == "%in%") {
        expr <- paste0(var, " %in% c(", paste(paste0("'", val, "'"), collapse = ","), ")")
      } else if (op == "!%in%") {
        expr <- paste0("!(", var, " %in% c(", paste(paste0("'", val, "'"), collapse = ","), "))")
      } else if (is.character(val)) {
        expr <- paste0(var, " ", op, " '", val, "'")
      } else {
        expr <- paste0(var, " ", op, " ", val)
      }

      filter_expressions <- c(filter_expressions, expr)
    }

    # Combine expressions
    if (logic == "AND") {
      full_expression <- paste(filter_expressions, collapse = " & ")
    } else {
      full_expression <- paste(filter_expressions, collapse = " | ")
    }

    # Apply filter
    filtered_data <- data %>% filter(eval(parse(text = full_expression)))

    result <- list(
      data = filtered_data,
      filter_expression = full_expression,
      original_rows = nrow(data),
      filtered_rows = nrow(filtered_data),
      rows_removed = nrow(data) - nrow(filtered_data),
      removal_percentage = (nrow(data) - nrow(filtered_data)) / nrow(data) * 100
    )
    '''

    try:
        result = execute_r_script(r_script, params)
        await context.info("Data filtered successfully")
        return result

    except Exception as e:
        await context.error("Data filtering failed", error=str(e))
        raise
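For orientation, below is a minimal sketch of how one of these file-operation tools could be exercised directly, for example in a test harness, outside the MCP transport. It assumes the @tool decorator leaves the coroutine callable as-is and that the table payload accepted by table_schema() is column-oriented (a mapping of column name to values, which R's as.data.frame() handles after JSON decoding); the FakeContext class is a hypothetical stand-in for the real context from rmcp.core.context, and running it requires the installed package plus a working R interpreter.

import asyncio

from rmcp.tools.fileops import data_info


class FakeContext:
    """Hypothetical stand-in for RMCP's real context object: it only logs."""

    async def info(self, message, **fields):
        print("INFO:", message, fields)

    async def error(self, message, **fields):
        print("ERROR:", message, fields)


async def main():
    # Assumed column-oriented payload: column name -> list of values.
    params = {
        "data": {"x": [1, 2, 3, 4], "group": ["a", "b", "a", "b"]},
        "include_sample": True,
        "sample_size": 2,
    }
    info = await data_info(FakeContext(), params)  # needs R available to execute_r_script
    print(info["dimensions"], info["variables"]["numeric"])


if __name__ == "__main__":
    asyncio.run(main())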
rmcp/tools/machine_learning.py
ADDED
@@ -0,0 +1,299 @@
"""
Machine learning tools for RMCP.

Clustering, classification trees, and ML capabilities.
"""

from typing import Dict, Any
from ..registries.tools import tool
from ..core.schemas import table_schema, formula_schema
from ..r_integration import execute_r_script


@tool(
    name="kmeans_clustering",
    input_schema={
        "type": "object",
        "properties": {
            "data": table_schema(),
            "variables": {"type": "array", "items": {"type": "string"}},
            "k": {"type": "integer", "minimum": 2, "maximum": 20},
            "max_iter": {"type": "integer", "minimum": 1, "default": 100},
            "nstart": {"type": "integer", "minimum": 1, "default": 25}
        },
        "required": ["data", "variables", "k"]
    },
    description="K-means clustering analysis with cluster validation"
)
async def kmeans_clustering(context, params):
    """Perform K-means clustering."""

    await context.info("Performing K-means clustering")

    r_script = '''
    data <- as.data.frame(args$data)
    variables <- args$variables
    k <- args$k
    max_iter <- args$max_iter %||% 100
    nstart <- args$nstart %||% 25

    # Select and prepare data
    cluster_data <- data[, variables, drop = FALSE]
    cluster_data <- na.omit(cluster_data)

    # Scale variables for clustering
    scaled_data <- scale(cluster_data)

    # Perform k-means
    set.seed(123)  # For reproducibility
    kmeans_result <- kmeans(scaled_data, centers = k, iter.max = max_iter, nstart = nstart)

    # Calculate cluster statistics
    cluster_centers <- kmeans_result$centers
    cluster_assignments <- kmeans_result$cluster

    # Within-cluster sum of squares
    wss <- kmeans_result$withinss
    total_wss <- kmeans_result$tot.withinss
    between_ss <- kmeans_result$betweenss
    total_ss <- kmeans_result$totss

    # Cluster sizes
    cluster_sizes <- table(cluster_assignments)

    # Silhouette analysis (if package available)
    silhouette_score <- NA
    if (require(cluster, quietly = TRUE)) {
      library(cluster)
      sil <- silhouette(cluster_assignments, dist(scaled_data))
      silhouette_score <- mean(sil[, 3])
    }

    result <- list(
      cluster_assignments = cluster_assignments,
      cluster_centers = as.data.frame(cluster_centers),
      cluster_sizes = as.list(cluster_sizes),
      within_ss = wss,
      total_within_ss = total_wss,
      between_ss = between_ss,
      total_ss = total_ss,
      variance_explained = between_ss / total_ss * 100,
      silhouette_score = silhouette_score,
      k = k,
      variables = variables,
      n_obs = nrow(cluster_data),
      converged = kmeans_result$iter < max_iter
    )
    '''

    try:
        result = execute_r_script(r_script, params)
        await context.info("K-means clustering completed successfully")
        return result

    except Exception as e:
        await context.error("K-means clustering failed", error=str(e))
        raise


@tool(
    name="decision_tree",
    input_schema={
        "type": "object",
        "properties": {
            "data": table_schema(),
            "formula": formula_schema(),
            "type": {"type": "string", "enum": ["classification", "regression"], "default": "classification"},
            "min_split": {"type": "integer", "minimum": 1, "default": 20},
            "max_depth": {"type": "integer", "minimum": 1, "default": 30}
        },
        "required": ["data", "formula"]
    },
    description="Decision tree classification and regression"
)
async def decision_tree(context, params):
    """Build decision tree model."""

    await context.info("Building decision tree")

    r_script = '''
    if (!require(rpart)) install.packages("rpart", quietly = TRUE)
    library(rpart)

    data <- as.data.frame(args$data)
    formula <- as.formula(args$formula)
    tree_type <- args$type %||% "classification"
    min_split <- args$min_split %||% 20
    max_depth <- args$max_depth %||% 30

    # Set method based on type
    if (tree_type == "classification") {
      method <- "class"
    } else {
      method <- "anova"
    }

    # Build tree
    tree_model <- rpart(formula, data = data, method = method,
                        control = rpart.control(minsplit = min_split, maxdepth = max_depth))

    # Get predictions
    predictions <- predict(tree_model, type = if (method == "class") "class" else "vector")

    # Calculate performance metrics
    if (tree_type == "classification") {
      # Classification metrics
      response_var <- all.vars(formula)[1]
      actual <- data[[response_var]]
      confusion_matrix <- table(Predicted = predictions, Actual = actual)
      accuracy <- sum(diag(confusion_matrix)) / sum(confusion_matrix)

      performance <- list(
        accuracy = accuracy,
        confusion_matrix = as.matrix(confusion_matrix)
      )
    } else {
      # Regression metrics
      response_var <- all.vars(formula)[1]
      actual <- data[[response_var]]
      mse <- mean((predictions - actual)^2, na.rm = TRUE)
      rmse <- sqrt(mse)
      r_squared <- 1 - sum((actual - predictions)^2, na.rm = TRUE) / sum((actual - mean(actual, na.rm = TRUE))^2, na.rm = TRUE)

      performance <- list(
        mse = mse,
        rmse = rmse,
        r_squared = r_squared
      )
    }

    # Variable importance
    var_importance <- tree_model$variable.importance

    result <- list(
      tree_type = tree_type,
      performance = performance,
      variable_importance = as.list(var_importance),
      predictions = as.numeric(predictions),
      n_nodes = nrow(tree_model$frame),
      n_obs = nrow(data),
      formula = deparse(formula),
      tree_complexity = tree_model$cptable[nrow(tree_model$cptable), "CP"]
    )
    '''

    try:
        result = execute_r_script(r_script, params)
        await context.info("Decision tree built successfully")
        return result

    except Exception as e:
        await context.error("Decision tree building failed", error=str(e))
        raise


@tool(
    name="random_forest",
    input_schema={
        "type": "object",
        "properties": {
            "data": table_schema(),
            "formula": formula_schema(),
            "n_trees": {"type": "integer", "minimum": 1, "maximum": 1000, "default": 500},
            "mtry": {"type": "integer", "minimum": 1},
            "importance": {"type": "boolean", "default": True}
        },
        "required": ["data", "formula"]
    },
    description="Random Forest ensemble model for classification and regression"
)
async def random_forest(context, params):
    """Build Random Forest model."""

    await context.info("Building Random Forest model")

    r_script = '''
    if (!require(randomForest)) install.packages("randomForest", quietly = TRUE)
    library(randomForest)

    data <- as.data.frame(args$data)
    formula <- as.formula(args$formula)
    n_trees <- args$n_trees %||% 500
    mtry_val <- args$mtry
    importance <- args$importance %||% TRUE

    # Determine problem type
    response_var <- all.vars(formula)[1]
    if (is.factor(data[[response_var]]) || is.character(data[[response_var]])) {
      # Convert to factor if character
      if (is.character(data[[response_var]])) {
        data[[response_var]] <- as.factor(data[[response_var]])
      }
      problem_type <- "classification"
    } else {
      problem_type <- "regression"
    }

    # Set default mtry if not provided
    if (is.null(mtry_val)) {
      n_predictors <- length(all.vars(formula)[-1])
      if (problem_type == "classification") {
        mtry_val <- floor(sqrt(n_predictors))
      } else {
        mtry_val <- floor(n_predictors / 3)
      }
    }

    # Build Random Forest
    rf_model <- randomForest(formula, data = data, ntree = n_trees,
                             mtry = mtry_val, importance = importance)

    # Extract results
    if (problem_type == "classification") {
      confusion_matrix <- rf_model$confusion[, -ncol(rf_model$confusion)]  # Remove class.error column
      oob_error <- rf_model$err.rate[n_trees, "OOB"]

      performance <- list(
        oob_error_rate = oob_error,
        confusion_matrix = as.matrix(confusion_matrix),
        class_error = as.list(rf_model$confusion[, "class.error"])
      )
    } else {
      mse <- rf_model$mse[n_trees]
      variance_explained <- (1 - mse / var(data[[response_var]], na.rm = TRUE)) * 100

      performance <- list(
        mse = mse,
        rmse = sqrt(mse),
        variance_explained = variance_explained
      )
    }

    # Variable importance
    if (importance) {
      var_imp <- importance(rf_model)
      var_importance <- as.data.frame(var_imp)
    } else {
      var_importance <- NULL
    }

    result <- list(
      problem_type = problem_type,
      performance = performance,
      variable_importance = var_importance,
      n_trees = n_trees,
      mtry = rf_model$mtry,
      oob_error = rf_model$err.rate[n_trees, 1],
      formula = deparse(formula),
      n_obs = nrow(data)
    )
    '''

    try:
        result = execute_r_script(r_script, params)
        await context.info("Random Forest model built successfully")
        return result

    except Exception as e:
        await context.error("Random Forest building failed", error=str(e))
        raise
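As with the fileops module, here is a minimal sketch of driving one of these tools directly. The FakeContext class, the column-oriented shape of the data payload, and direct invocation of the decorated coroutine are illustrative assumptions rather than behavior guaranteed by this diff; running it requires the installed package and an R interpreter (plus the cluster package for the silhouette score).

import asyncio

from rmcp.tools.machine_learning import kmeans_clustering


class FakeContext:
    """Hypothetical logging-only context, as in the fileops sketch above."""

    async def info(self, message, **fields):
        print("INFO:", message, fields)

    async def error(self, message, **fields):
        print("ERROR:", message, fields)


async def main():
    # Assumed column-oriented payload; per the input_schema above,
    # k must lie between 2 and 20 and the named variables must exist in data.
    params = {
        "data": {
            "height": [150.0, 152.0, 168.0, 171.0, 181.0, 183.0],
            "weight": [48.0, 51.0, 62.0, 65.0, 80.0, 82.0],
        },
        "variables": ["height", "weight"],
        "k": 2,
        "nstart": 10,
    }
    result = await kmeans_clustering(FakeContext(), params)  # needs R on PATH
    print(result["cluster_sizes"], round(result["variance_explained"], 1))


if __name__ == "__main__":
    asyncio.run(main())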