linkml-store 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (31)
  1. linkml_store/api/client.py +37 -8
  2. linkml_store/api/collection.py +81 -9
  3. linkml_store/api/config.py +28 -1
  4. linkml_store/api/database.py +26 -3
  5. linkml_store/api/stores/mongodb/mongodb_collection.py +4 -0
  6. linkml_store/api/stores/neo4j/__init__.py +0 -0
  7. linkml_store/api/stores/neo4j/neo4j_collection.py +429 -0
  8. linkml_store/api/stores/neo4j/neo4j_database.py +154 -0
  9. linkml_store/cli.py +140 -13
  10. linkml_store/graphs/__init__.py +0 -0
  11. linkml_store/graphs/graph_map.py +24 -0
  12. linkml_store/inference/__init__.py +13 -0
  13. linkml_store/inference/implementations/__init__.py +0 -0
  14. linkml_store/inference/implementations/rag_inference_engine.py +145 -0
  15. linkml_store/inference/implementations/rule_based_inference_engine.py +158 -0
  16. linkml_store/inference/implementations/sklearn_inference_engine.py +290 -0
  17. linkml_store/inference/inference_config.py +62 -0
  18. linkml_store/inference/inference_engine.py +173 -0
  19. linkml_store/inference/inference_engine_registry.py +74 -0
  20. linkml_store/utils/format_utils.py +21 -90
  21. linkml_store/utils/llm_utils.py +95 -0
  22. linkml_store/utils/neo4j_utils.py +42 -0
  23. linkml_store/utils/object_utils.py +3 -1
  24. linkml_store/utils/pandas_utils.py +55 -2
  25. linkml_store/utils/sklearn_utils.py +193 -0
  26. linkml_store/utils/stats_utils.py +53 -0
  27. {linkml_store-0.1.12.dist-info → linkml_store-0.1.14.dist-info}/METADATA +30 -3
  28. {linkml_store-0.1.12.dist-info → linkml_store-0.1.14.dist-info}/RECORD +31 -14
  29. {linkml_store-0.1.12.dist-info → linkml_store-0.1.14.dist-info}/LICENSE +0 -0
  30. {linkml_store-0.1.12.dist-info → linkml_store-0.1.14.dist-info}/WHEEL +0 -0
  31. {linkml_store-0.1.12.dist-info → linkml_store-0.1.14.dist-info}/entry_points.txt +0 -0
linkml_store/utils/format_utils.py

@@ -27,6 +27,7 @@ class Format(Enum):
     JSON = "json"
     JSONL = "jsonl"
     YAML = "yaml"
+    YAMLL = "yamll"
     TSV = "tsv"
     CSV = "csv"
     PYTHON = "python"
@@ -63,6 +64,9 @@ class Format(Enum):
     def is_dump_format(self):
         return self in [Format.SQLDUMP_DUCKDB, Format.SQLDUMP_POSTGRES, Format.DUMP_MONGODB]
 
+    def is_xsv(self):
+        return self in [Format.TSV, Format.CSV]
+
 
 def load_objects_from_url(
     url: str,
@@ -135,11 +139,14 @@ def load_objects(
     compression: Optional[str] = None,
     expected_type: Optional[Type] = None,
     header_comment_token: Optional[str] = None,
+    select_query: Optional[str] = None,
 ) -> List[Dict[str, Any]]:
     """
     Load objects from a file or archive in supported formats.
     For tgz archives, it processes all files and concatenates the results.
 
+    TODO: Add schema hints for CSV/TSV parsing.
+
     :param file_path: The path to the file or archive.
     :param format: The format of the file. Can be a Format enum or a string value.
     :param compression: The compression type. Supports 'gz' for gzip and 'tgz' for tar.gz.
@@ -177,98 +184,22 @@ def load_objects(
         all_objects = process_file(f, format, expected_type, header_comment_token)
 
     logger.debug(f"Loaded {len(all_objects)} objects from {file_path}")
+    if select_query:
+        import jsonpath_ng as jp
+
+        path_expr = jp.parse(select_query)
+        new_objs = []
+        for obj in all_objects:
+            for match in path_expr.find(obj):
+                logging.debug(f"Match: {match.value}")
+                if isinstance(match.value, list):
+                    new_objs.extend(match.value)
+                else:
+                    new_objs.append(match.value)
+        all_objects = new_objs
     return all_objects
 
 
-def xxxload_objects(
-    file_path: Union[str, Path],
-    format: Union[Format, str] = None,
-    compression: Optional[str] = None,
-    expected_type: Type = None,
-    header_comment_token: Optional[str] = None,
-) -> List[Dict[str, Any]]:
-    """
-    Load objects from a file in JSON, JSONLines, YAML, CSV, or TSV format.
-
-    >>> load_objects("tests/input/test_data/data.csv")
-    [{'id': '1', 'name': 'John', 'age': '30'},
-    {'id': '2', 'name': 'Alice', 'age': '25'}, {'id': '3', 'name': 'Bob', 'age': '35'}]
-
-    :param file_path: The path to the file.
-    :param format: The format of the file. Can be a Format enum or a string value.
-    :param expected_type: The target type to load the objects into, e.g. list
-    :return: A list of dictionaries representing the loaded objects.
-    """
-    if isinstance(format, str):
-        format = Format(format)
-
-    if isinstance(file_path, Path):
-        file_path = str(file_path)
-
-    if not format and (file_path.endswith(".parquet") or file_path.endswith(".pq")):
-        format = Format.PARQUET
-    if not format and file_path.endswith(".tsv"):
-        format = Format.TSV
-    if not format and file_path.endswith(".csv"):
-        format = Format.CSV
-    if not format and file_path.endswith(".py"):
-        format = Format.PYTHON
-
-    mode = "r"
-    if format == Format.PARQUET:
-        mode = "rb"
-
-    if file_path == "-":
-        # set file_path to be a stream from stdin
-        f = sys.stdin
-    else:
-        f = open(file_path, mode)
-
-    if format == Format.JSON or (not format and file_path.endswith(".json")):
-        objs = json.load(f)
-    elif format == Format.JSONL or (not format and file_path.endswith(".jsonl")):
-        objs = [json.loads(line) for line in f]
-    elif format == Format.YAML or (not format and (file_path.endswith(".yaml") or file_path.endswith(".yml"))):
-        if expected_type and expected_type == list:  # noqa E721
-            objs = list(yaml.safe_load_all(f))
-        else:
-            objs = yaml.safe_load(f)
-    elif format == Format.TSV or format == Format.CSV:
-        # Skip initial comment lines if comment_char is set
-        if header_comment_token:
-            # Store the original position
-            original_pos = f.tell()
-
-            # Read and store lines until we find a non-comment line
-            lines = []
-            for line in f:
-                if not line.startswith(header_comment_token):
-                    break
-                lines.append(line)
-
-            # Go back to the original position
-            f.seek(original_pos)
-
-            # Skip the comment lines we found
-            for _ in lines:
-                f.readline()
-        if format == Format.TSV:
-            reader = csv.DictReader(f, delimiter="\t")
-        else:
-            reader = csv.DictReader(f)
-        objs = list(reader)
-    elif format == Format.PARQUET:
-        import pyarrow.parquet as pq
-
-        table = pq.read_table(f)
-        objs = table.to_pandas().to_dict(orient="records")
-    else:
-        raise ValueError(f"Unsupported file format: {file_path}")
-    if not isinstance(objs, list):
-        objs = [objs]
-    return objs
-
-
 def write_output(
     data: Union[List[Dict[str, Any]], Dict[str, Any], pd.DataFrame],
     format: Union[Format, str] = Format.YAML,
@@ -329,7 +260,7 @@ def render_output(
     if format == Format.FORMATTED:
         if not isinstance(data, pd.DataFrame):
             data = pd.DataFrame(data)
-        return str(data)
+        return data.to_string(max_rows=None)
 
     if isinstance(data, pd.DataFrame):
         data = data.to_dict(orient="records")
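
The new select_query parameter filters loaded objects through a JSONPath expression (via jsonpath-ng, which the added code imports) before they are returned. A minimal usage sketch; the file path and document shape below are hypothetical:

    from linkml_store.utils.format_utils import load_objects

    # Hypothetical input: a YAML file whose top-level object holds a list under "persons".
    # The JSONPath match is a list, so it is extend()-ed into a flat list of person dicts
    # rather than returned as a single wrapped object.
    objs = load_objects("data/people.yaml", select_query="$.persons")
    for obj in objs:
        print(obj["name"])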
linkml_store/utils/llm_utils.py (new file)

@@ -0,0 +1,95 @@
+from typing import Callable, List, Optional
+
+from tiktoken import Encoding
+
+MODEL_TOKEN_MAPPING = {
+    "gpt-4o-mini": 128_000,
+    "gpt-4o": 128_000,
+    "gpt-4o-2024-05-13": 128_000,
+    "gpt-4": 8192,
+    "gpt-4-0314": 8192,
+    "gpt-4-0613": 8192,
+    "gpt-4-32k": 32768,
+    "gpt-4-32k-0314": 32768,
+    "gpt-4-32k-0613": 32768,
+    "gpt-3.5-turbo": 4096,
+    "gpt-3.5-turbo-0301": 4096,
+    "gpt-3.5-turbo-0613": 4096,
+    "gpt-3.5-turbo-16k": 16385,
+    "gpt-3.5-turbo-16k-0613": 16385,
+    "gpt-3.5-turbo-instruct": 4096,
+    "text-ada-001": 2049,
+    "ada": 2049,
+    "text-babbage-001": 2040,
+    "babbage": 2049,
+    "text-curie-001": 2049,
+    "curie": 2049,
+    "davinci": 2049,
+    "text-davinci-003": 4097,
+    "text-davinci-002": 4097,
+    "code-davinci-002": 8001,
+    "code-davinci-001": 8001,
+    "code-cushman-002": 2048,
+    "code-cushman-001": 2048,
+    "claude": 200_000,
+}
+
+
+def render_formatted_text(
+    render_func: Callable,
+    values: List[str],
+    encoding: Encoding,
+    token_limit: int,
+    additional_text: Optional[str] = None,
+) -> str:
+    """
+    Render a formatted text string with a given object, encoding, and token limit.
+
+    >>> from tiktoken import encoding_for_model
+    >>> encoding = encoding_for_model("gpt-4o-mini")
+    >>> names = ["Alice", "Bob", "DoctorHippopotamusMcHippopotamusFace"]
+    >>> f = lambda x: f"Hello, {' '.join(x)}!"
+    >>> render_formatted_text(f, names, encoding, 4096)
+    'Hello, Alice Bob DoctorHippopotamusMcHippopotamusFace!'
+    >>> render_formatted_text(f, names, encoding, 5)
+    'Hello, Alice Bob!'
+
+    :param render_func: Rendering function
+    :param values: Values to render
+    :param encoding: Encoding
+    :param token_limit: Token limit
+    :param additional_text: Additional text to consider
+    :return:
+    """
+    text = render_func(values)
+    if additional_text:
+        token_limit -= len(encoding.encode(additional_text))
+    text_length = len(encoding.encode(text))
+    if text_length <= token_limit:
+        return text
+    if not values:
+        raise ValueError(f"Cannot fit text into token limit: {text_length} > {token_limit}")
+    return render_formatted_text(render_func, values[0:-1], encoding=encoding, token_limit=token_limit)
+
+
+def get_token_limit(model_name: str) -> int:
+    """
+    Estimate the token limit for a model.
+
+    >>> get_token_limit("gpt-4o-mini")
+    128000
+
+    also works with nested names:
+
+    >>> get_token_limit("my/claude-opus")
+    200000
+
+
+    :param model_name: Model name
+    :return: Estimated token limit
+    """
+    # sort MODEL_TOKEN_MAPPING by key length to ensure that the longest model names are checked first
+    for model, token_limit in sorted(MODEL_TOKEN_MAPPING.items(), key=lambda x: len(x[0]), reverse=True):
+        if model in model_name:
+            return token_limit
+    return 4096
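
render_formatted_text drops trailing values until the rendered string fits the token budget, so it can be used to pack as many retrieved documents as possible into a prompt. A small sketch assuming tiktoken is installed; the document list and instruction text are made up for illustration:

    from tiktoken import encoding_for_model

    from linkml_store.utils.llm_utils import get_token_limit, render_formatted_text

    docs = ["first retrieved passage ...", "second retrieved passage ...", "third ..."]
    encoding = encoding_for_model("gpt-4o-mini")
    budget = get_token_limit("gpt-4o-mini")  # 128000 per MODEL_TOKEN_MAPPING

    # Joins as many passages as fit; the additional_text is subtracted from the budget first.
    prompt_body = render_formatted_text(
        lambda xs: "\n\n".join(xs),
        values=docs,
        encoding=encoding,
        token_limit=budget,
        additional_text="Answer the question using only the passages above.",
    )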
linkml_store/utils/neo4j_utils.py (new file)

@@ -0,0 +1,42 @@
+import networkx as nx
+from py2neo import Graph
+
+
+def draw_neo4j_graph(handle="bolt://localhost:7687", auth=("neo4j", None)):
+    # Connect to Neo4j
+    graph = Graph(handle, auth=auth)
+
+    # Run a Cypher query
+    query = """
+    MATCH (n)-[r]->(m)
+    RETURN n, r, m
+    LIMIT 100
+    """
+    result = graph.run(query)
+
+    # Create a NetworkX graph
+    G = nx.DiGraph()  # Use DiGraph for directed edges
+    for record in result:
+        n = record["n"]
+        m = record["m"]
+        r = record["r"]
+        G.add_node(n["name"], label=list(n.labels or ["-"])[0])
+        G.add_node(m["name"], label=list(m.labels or ["-"])[0])
+        G.add_edge(n["name"], m["name"], type=type(r).__name__)
+
+    # Draw the graph
+    pos = nx.spring_layout(G)
+
+    # Draw nodes
+    nx.draw_networkx_nodes(G, pos, node_color="lightblue", node_size=10000)
+
+    # Draw edges
+    nx.draw_networkx_edges(G, pos, edge_color="gray", arrows=True)
+
+    # Add node labels
+    node_labels = nx.get_node_attributes(G, "label")
+    nx.draw_networkx_labels(G, pos, {node: f"{node}\n({label})" for node, label in node_labels.items()}, font_size=16)
+
+    # Add edge labels
+    edge_labels = nx.get_edge_attributes(G, "type")
+    nx.draw_networkx_edge_labels(G, pos, edge_labels, font_size=16)
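
draw_neo4j_graph only issues networkx drawing calls onto the current matplotlib axes; it neither creates a figure nor displays it. A usage sketch, assuming a Neo4j instance is reachable at the default bolt URL, that its nodes carry a name property (as the Cypher query above requires), and placeholder credentials:

    import matplotlib.pyplot as plt

    from linkml_store.utils.neo4j_utils import draw_neo4j_graph

    plt.figure(figsize=(12, 12))
    draw_neo4j_graph("bolt://localhost:7687", auth=("neo4j", "password"))  # credentials are placeholders
    plt.axis("off")
    plt.show()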
linkml_store/utils/object_utils.py

@@ -29,7 +29,7 @@ def object_path_update(
     """
     if isinstance(obj, BaseModel):
         typ = type(obj)
-        obj = obj.dict()
+        obj = obj.model_dump(exclude_none=True)
         obj = object_path_update(obj, path, value)
         return typ(**obj)
     obj = deepcopy(obj)
@@ -45,6 +45,8 @@ def object_path_update(
             obj.append({})
            obj = obj[index]
        else:
+            if part in obj and obj[part] is None:
+                del obj[part]
            obj = obj.setdefault(part, {})
    last_part = parts[-1]
    if "[" in last_part:
linkml_store/utils/pandas_utils.py

@@ -1,7 +1,59 @@
-from typing import Dict, List, Tuple, Union
+import logging
+from typing import Any, Dict, List, Tuple, Union
 
 import pandas as pd
 
+logger = logging.getLogger(__name__)
+
+
+def flatten_dict(d: Dict[str, Any], parent_key: str = "", sep: str = ".") -> Dict[str, Any]:
+    """
+    Recursively flatten a nested dictionary.
+
+    Args:
+        d (Dict[str, Any]): The dictionary to flatten.
+        parent_key (str): The parent key for nested dictionaries.
+        sep (str): The separator to use between keys.
+
+    Returns:
+        Dict[str, Any]: A flattened dictionary.
+
+    >>> flatten_dict({'a': 1, 'b': {'c': 2, 'd': {'e': 3}}})
+    {'a': 1, 'b.c': 2, 'b.d.e': 3}
+    """
+    items = []
+    for k, v in d.items():
+        new_key = f"{parent_key}{sep}{k}" if parent_key else k
+        if isinstance(v, dict):
+            items.extend(flatten_dict(v, new_key, sep=sep).items())
+        else:
+            items.append((new_key, v))
+    return dict(items)
+
+
+def nested_objects_to_dataframe(data: List[Dict[str, Any]]) -> pd.DataFrame:
+    """
+    Convert a list of nested objects to a flattened pandas DataFrame.
+
+    Args:
+        data (List[Dict[str, Any]]): A list of nested dictionaries.
+
+    Returns:
+        pd.DataFrame: A flattened DataFrame.
+
+    >>> data = [
+    ...     {"person": {"name": "Alice", "age": 30}, "job": {"title": "Engineer", "salary": 75000}},
+    ...     {"person": {"name": "Bob", "age": 35}, "job": {"title": "Manager", "salary": 85000}}
+    ... ]
+    >>> df = nested_objects_to_dataframe(data)
+    >>> df.columns.tolist()
+    ['person.name', 'person.age', 'job.title', 'job.salary']
+    >>> df['person.name'].tolist()
+    ['Alice', 'Bob']
+    """
+    flattened_data = [flatten_dict(item) for item in data]
+    return pd.DataFrame(flattened_data)
+
 
 def facet_summary_to_dataframe_unmelted(
     facet_summary: Dict[Union[str, Tuple[str, ...]], List[Tuple[Union[str, Tuple[str, ...]], int]]]
@@ -22,7 +74,8 @@ def facet_summary_to_dataframe_unmelted(
            categories, value = cat_val_tuple[:-1], cat_val_tuple[-1]
            row = {"Value": value}
            for i, facet in enumerate(facet_type):
-                row[facet] = categories[i]
+                logger.debug(f"FT={facet_type} i={i} Facet: {facet}, categories: {categories}")
+                row[facet] = categories[i] if len(categories) > i else None
            rows.append(row)
 
    df = pd.DataFrame(rows)
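
One behaviour worth noting, not covered by the doctests above: flatten_dict only recurses into dicts, so list values are carried over as whole cell values rather than exploded into rows. A quick sketch with an invented record:

    from linkml_store.utils.pandas_utils import nested_objects_to_dataframe

    rows = [{"id": "x1", "meta": {"tags": ["a", "b"], "source": "manual"}}]
    df = nested_objects_to_dataframe(rows)
    print(df.columns.tolist())     # ['id', 'meta.tags', 'meta.source']
    print(df.loc[0, "meta.tags"])  # ['a', 'b'] -- the list is kept intact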
linkml_store/utils/sklearn_utils.py (new file)

@@ -0,0 +1,193 @@
+import logging
+import os
+import re
+import shutil
+from pathlib import Path
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+from linkml_runtime.utils.formatutils import underscore
+from sklearn.preprocessing import LabelEncoder, OneHotEncoder
+from sklearn.tree import DecisionTreeClassifier, _tree, export_graphviz
+
+logger = logging.getLogger(__name__)
+
+
+def tree_to_nested_expression(
+    tree: DecisionTreeClassifier,
+    feature_names: List[str],
+    categorical_features: Optional[List[str]] = None,
+    feature_encoders: Optional[Dict[str, Union[OneHotEncoder, LabelEncoder]]] = None,
+    target_encoder: Optional[LabelEncoder] = None,
+) -> str:
+    """
+    Convert a trained scikit-learn DecisionTreeClassifier to a nested Python conditional expression.
+
+    Args:
+        tree (DecisionTreeClassifier): A trained decision tree classifier.
+        feature_names (list): List of feature names (including one-hot encoded feature names).
+        categorical_features (list): List of original categorical feature names.
+        feature_encoders (dict): Dictionary mapping feature names to their respective OneHotEncoders or LabelEncoders.
+        target_encoder (LabelEncoder, optional): LabelEncoder for the target variable if it's categorical.
+
+    Returns:
+        str: A string representing the nested Python conditional expression.
+
+    Example:
+        >>> import numpy as np
+        >>> from sklearn.tree import DecisionTreeClassifier
+        >>> from sklearn.preprocessing import OneHotEncoder, LabelEncoder
+        >>>
+        >>> # Prepare sample data
+        >>> X = np.array([[0, 'A'], [0, 'B'], [1, 'A'], [1, 'B']])
+        >>> y = np.array(['No', 'Yes', 'Yes', 'No'])
+        >>>
+        >>> # Prepare the encoders
+        >>> feature_encoders = {'feature2': OneHotEncoder(sparse_output=False, handle_unknown='ignore')}
+        >>> target_encoder = LabelEncoder()
+        >>>
+        >>> # Encode the categorical feature and target
+        >>> X_encoded = np.column_stack([
+        ...     X[:, 0],
+        ...     feature_encoders['feature2'].fit_transform(X[:, 1].reshape(-1, 1))
+        ... ])
+        >>> y_encoded = target_encoder.fit_transform(y)
+        >>>
+        >>> # Train the decision tree
+        >>> clf = DecisionTreeClassifier(random_state=42)
+        >>> clf.fit(X_encoded, y_encoded)
+        DecisionTreeClassifier(random_state=42)
+        >>>
+        >>> # Convert to nested expression
+        >>> feature_names = ['feature1', 'feature2_A', 'feature2_B']
+        >>> categorical_features = ['feature2']
+        >>> expression = tree_to_nested_expression(clf, feature_names,
+        ...     categorical_features, feature_encoders, target_encoder)
+        >>> print(expression)
+        (("Yes" if ({feature1} <= 0.5000) else "No") if ({feature2} == "A")
+        else ("No" if ({feature1} <= 0.5000) else "Yes"))
+    """
+    tree_ = tree.tree_
+    feature_name = [feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!" for i in tree_.feature]
+
+    categorical_features = set(categorical_features or [])
+
+    def get_original_feature_name(name):
+        return name.split("_")[0] if "_" in name else name
+
+    def recurse(node):
+        if tree_.feature[node] != _tree.TREE_UNDEFINED:
+            name = feature_name[node]
+            threshold = tree_.threshold[node]
+            original_name = get_original_feature_name(name)
+            original_name_safe = underscore(original_name)
+            name_safe = underscore(name)
+
+            original_name_safe = "{" + original_name_safe + "}"
+            name_safe = "{" + name_safe + "}"
+
+            if original_name in categorical_features:
+                if feature_encoders is None or original_name not in feature_encoders:
+                    raise ValueError(f"Encoder is required for categorical feature {original_name}")
+
+                encoder = feature_encoders[original_name]
+                if isinstance(encoder, OneHotEncoder):
+                    # For one-hot encoded features, we check if the specific category is present
+                    category = name.split("_", 1)[1]  # Get everything after the first underscore
+                    condition = f'{original_name_safe} == "{category}"'
+                elif isinstance(encoder, LabelEncoder):
+                    category = encoder.inverse_transform([int(threshold)])[0]
+                    condition = f'{original_name_safe} == "{category}"'
+                else:
+                    raise ValueError(f"Unsupported encoder type for feature {original_name}")
+            else:
+                if np.isinf(threshold):
+                    condition = "True"
+                else:
+                    condition = f"{name_safe} <= {threshold:.4f}"
+
+            left_expr = recurse(tree_.children_left[node])
+            right_expr = recurse(tree_.children_right[node])
+
+            return f"({left_expr} if ({condition}) else {right_expr})"
+        else:
+            class_index = np.argmax(tree_.value[node])
+            if target_encoder:
+                class_label = target_encoder.inverse_transform([class_index])[0]
+                return f'"{class_label}"'
+            else:
+                return str(class_index)
+
+    return recurse(0)
+
+
+def escape_label(s: str) -> str:
+    """Escape special characters in label strings."""
+    s = str(s)
+    return re.sub(r"([<>])", r"\\\1", s)
+
+
+def visualize_decision_tree(
+    clf: DecisionTreeClassifier,
+    feature_names: List[str],
+    class_names: List[str] = None,
+    output_file: Union[Path, str] = "decision_tree.png",
+) -> None:
+    """
+    Generate a visualization of the decision tree and save it as a PNG file.
+
+    :param clf: Trained DecisionTreeClassifier
+    :param feature_names: List of feature names
+    :param class_names: List of class names (optional)
+    :param output_file: The name of the file to save the visualization (default: "decision_tree.png")
+
+    >>> # Create a sample dataset
+    >>> import pandas as pd
+    >>> data = pd.DataFrame({
+    ...     'age': [25, 30, 35, 40, 45],
+    ...     'income': [50000, 60000, 70000, 80000, 90000],
+    ...     'credit_score': [600, 650, 700, 750, 800],
+    ...     'approved': ['No', 'No', 'Yes', 'Yes', 'Yes']
+    ... })
+    >>>
+    >>> # Prepare features and target
+    >>> X = data[['age', 'income', 'credit_score']]
+    >>> y = data['approved']
+    >>>
+    >>> # Encode target variable
+    >>> le = LabelEncoder()
+    >>> y_encoded = le.fit_transform(y)
+    >>>
+    >>> # Train a decision tree
+    >>> clf = DecisionTreeClassifier(random_state=42)
+    >>> _ = clf.fit(X, y_encoded)
+    >>> # Visualize the tree
+    >>> visualize_decision_tree(clf, X.columns.tolist(), le.classes_, "tests/output/test_tree.png")
+    """
+    # Escape special characters in feature names and class names
+    escaped_feature_names = [escape_label(name) for name in feature_names]
+    escaped_class_names = [escape_label(name) for name in (class_names if class_names is not None else [])]
+
+    import graphviz
+
+    dot_data = export_graphviz(
+        clf,
+        out_file=None,
+        feature_names=escaped_feature_names,
+        class_names=escaped_class_names,
+        filled=True,
+        rounded=True,
+        special_characters=True,
+    )
+    # dot_data = escape_label(dot_data)
+    logger.info(f"Dot: {dot_data}")
+    dot_path = shutil.which("dot")
+    if not dot_path:
+        logger.warning("Graphviz 'dot' executable not found in PATH. Skipping visualization.")
+        return
+    os.environ["GRAPHVIZ_DOT"] = dot_path
+
+    graph = graphviz.Source(dot_data)
+    if isinstance(output_file, Path):
+        output_file = str(output_file)
+    graph.render(output_file.rsplit(".", 1)[0], format="png", cleanup=True)
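
The generated expression uses {feature} placeholders, which suggests it is meant to be instantiated per record before evaluation (presumably by the rule-based inference engine added elsewhere in this release; the consuming code is not shown in this diff). A hedged sketch of one way such an expression could be applied to a single record, filling the placeholders with repr() so that string categories stay quoted:

    # Expression as produced by tree_to_nested_expression for the doctest above:
    expression = '(("Yes" if ({feature1} <= 0.5000) else "No") if ({feature2} == "A") else ("No" if ({feature1} <= 0.5000) else "Yes"))'

    record = {"feature1": 0, "feature2": "A"}
    # Substitute each placeholder with the repr of its value, then evaluate the
    # resulting pure-Python conditional expression.
    instantiated = expression.format(**{k: repr(v) for k, v in record.items()})
    print(eval(instantiated))  # 'Yes'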
linkml_store/utils/stats_utils.py (new file)

@@ -0,0 +1,53 @@
+import numpy as np
+import pandas as pd
+
+
+def predictive_power(df, target_col, feature_cols, cv=5):
+    from sklearn.model_selection import cross_val_score
+    from sklearn.preprocessing import LabelEncoder
+    from sklearn.tree import DecisionTreeClassifier
+
+    # Prepare the data
+    X = df[feature_cols].copy()  # Create an explicit copy
+    y = df[target_col].copy()
+
+    # Encode categorical variables
+    for col in X.columns:
+        if X[col].dtype == "object":
+            X[col] = LabelEncoder().fit_transform(X[col].astype(str))
+
+    if y.dtype == "object":
+        y = LabelEncoder().fit_transform(y.astype(str))
+
+    # Adjust cv based on the number of unique values in y
+    n_unique = len(np.unique(y))
+    cv = min(cv, n_unique)
+
+    # Train a decision tree and get cross-validated accuracy
+    clf = DecisionTreeClassifier(random_state=42)
+
+    if cv < 2:
+        # If cv is less than 2, we can't do cross-validation, so we'll just fit and score
+        clf.fit(X, y)
+        return clf.score(X, y)
+    else:
+        scores = cross_val_score(clf, X, y, cv=cv)
+        return scores.mean()
+
+
+def analyze_predictive_power(df, columns=None, cv=5):
+    if columns is None:
+        columns = df.columns
+    results = pd.DataFrame(index=columns, columns=["predictive_power", "features"])
+
+    for target_col in columns:
+        feature_cols = [col for col in columns if col != target_col]
+        try:
+            power = predictive_power(df, target_col, feature_cols, cv)
+            results.loc[target_col, "predictive_power"] = power
+            results.loc[target_col, "features"] = ", ".join(feature_cols)
+        except Exception as e:
+            print(f"Error processing {target_col}: {str(e)}")
+            results.loc[target_col, "predictive_power"] = np.nan
+
+    return results
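
predictive_power scores how well the remaining columns predict a target using a small decision tree (cross-validated when enough classes exist), and analyze_predictive_power runs that check for every column. A short usage sketch on an invented DataFrame:

    import pandas as pd

    from linkml_store.utils.stats_utils import analyze_predictive_power

    df = pd.DataFrame(
        {
            "species": ["cat", "cat", "dog", "dog", "cat", "dog"],
            "sound": ["meow", "meow", "woof", "woof", "meow", "woof"],
            "weight_kg": [4, 5, 20, 25, 4, 30],
        }
    )
    results = analyze_predictive_power(df, cv=3)
    print(results)  # one row per column: cross-validated accuracy plus the features used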