predikit 0.4.1__tar.gz → 0.4.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {predikit-0.4.1 → predikit-0.4.2}/.claude/settings.local.json +2 -1
- {predikit-0.4.1 → predikit-0.4.2}/.gitignore +1 -0
- {predikit-0.4.1 → predikit-0.4.2}/CHANGELOG.md +14 -0
- {predikit-0.4.1 → predikit-0.4.2}/PKG-INFO +58 -28
- {predikit-0.4.1 → predikit-0.4.2}/README.md +361 -331
- predikit-0.4.2/docs/logo.gif +0 -0
- predikit-0.4.2/docs/logo.png +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/pyproject.toml +2 -2
- {predikit-0.4.1 → predikit-0.4.2}/src/predikit/__init__.py +1 -1
- {predikit-0.4.1 → predikit-0.4.2}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/.github/workflows/publish.yml +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/.github/workflows/test.yml +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/.pre-commit-config.yaml +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/CLAUDE.md +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/CONTRIBUTING.md +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/LICENSE +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/examples/01_basic_sklearn.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/examples/02_xgboost_regression.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/examples/03_orlando_real_estate.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/examples/04_confidence_routing.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/examples/05_multi_model_ensemble.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/examples/06_mlflow_loader.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/examples/07_snowflake_loader.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/src/predikit/cli.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/src/predikit/coerce.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/src/predikit/ensemble.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/src/predikit/exceptions.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/src/predikit/exporters/__init__.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/src/predikit/exporters/langchain.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/src/predikit/exporters/openai.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/src/predikit/introspect.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/src/predikit/loaders/__init__.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/src/predikit/loaders/mlflow.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/src/predikit/loaders/snowflake.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/src/predikit/registry.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/src/predikit/tool.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/tests/__init__.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/tests/test_cli.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/tests/test_coerce.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/tests/test_confidence.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/tests/test_ensemble.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/tests/test_exporters_openai.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/tests/test_introspect.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/tests/test_loaders_mlflow.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/tests/test_loaders_snowflake.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/tests/test_logging.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/tests/test_registry.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/tests/test_tool.py +0 -0
- {predikit-0.4.1 → predikit-0.4.2}/tests/test_weighted_ensemble.py +0 -0
|
@@ -3,7 +3,8 @@
|
|
|
3
3
|
"allow": [
|
|
4
4
|
"Bash(git add *)",
|
|
5
5
|
"Bash(git commit -m ' *)",
|
|
6
|
-
"Bash(python -m pytest tests/ -v --tb=short --cov=src/predikit --cov-report=term-missing)"
|
|
6
|
+
"Bash(python -m pytest tests/ -v --tb=short --cov=src/predikit --cov-report=term-missing)",
|
|
7
|
+
"WebFetch(domain:github.com)"
|
|
7
8
|
]
|
|
8
9
|
}
|
|
9
10
|
}
|
|
@@ -6,6 +6,20 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
|
|
6
6
|
|
|
7
7
|
## [Unreleased]
|
|
8
8
|
|
|
9
|
+
## [0.4.2] - 2026-06-13
|
|
10
|
+
|
|
11
|
+
### Changed
|
|
12
|
+
- Redesigned PyPI/README hero: logo, centered tagline, and badges in a unified `<p align="center">` block
|
|
13
|
+
- Tagline moved from a `##` heading to a proper descriptive paragraph
|
|
14
|
+
- Badges converted to centered HTML `<img>` links for consistent rendering on PyPI
|
|
15
|
+
- Quick code teaser repositioned directly below badges (before Table of Contents)
|
|
16
|
+
- "Field naming rule" added to Table of Contents
|
|
17
|
+
- `ainvoke()` added to `ModelTool` Core API reference table
|
|
18
|
+
- `ModelEnsemble` Core API subsection added with constructor signature and full strategy table
|
|
19
|
+
- Project Traffic / download badge moved to bottom of README
|
|
20
|
+
- Development Status classifier bumped from `3 - Alpha` to `4 - Beta` in `pyproject.toml`
|
|
21
|
+
- Removed CI test status badge from README
|
|
22
|
+
|
|
9
23
|
## [0.4.1] - 2026-06-02
|
|
10
24
|
|
|
11
25
|
### Added
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: predikit
|
|
3
|
-
Version: 0.4.1
|
|
3
|
+
Version: 0.4.2
|
|
4
4
|
Summary: Turn any trained sklearn/XGBoost model into an LLM-callable tool with auto-generated schemas and typed I/O.
|
|
5
5
|
Project-URL: Homepage, https://github.com/Tejas-TA/predikit
|
|
6
6
|
Project-URL: Repository, https://github.com/Tejas-TA/predikit
|
|
@@ -10,7 +10,7 @@ Author-email: Tejas Tumakuru Ashok <tejasta@gmail.com>
|
|
|
10
10
|
License: MIT
|
|
11
11
|
License-File: LICENSE
|
|
12
12
|
Keywords: agents,function-calling,llm,ml-tools,sklearn,xgboost
|
|
13
|
-
Classifier: Development Status :: 3 - Alpha
|
|
13
|
+
Classifier: Development Status :: 4 - Beta
|
|
14
14
|
Classifier: Intended Audience :: Developers
|
|
15
15
|
Classifier: Intended Audience :: Science/Research
|
|
16
16
|
Classifier: License :: OSI Approved :: MIT License
|
|
@@ -48,34 +48,46 @@ Provides-Extra: xgboost
|
|
|
48
48
|
Requires-Dist: xgboost>=1.7; extra == 'xgboost'
|
|
49
49
|
Description-Content-Type: text/markdown
|
|
50
50
|
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
51
|
+
<p align="center">
|
|
52
|
+
<picture>
|
|
53
|
+
<source srcset="https://raw.githubusercontent.com/Tejas-TA/predikit/main/docs/logo.gif">
|
|
54
|
+
<img src="https://raw.githubusercontent.com/Tejas-TA/predikit/main/docs/logo.png" alt="predikit" width="500"/>
|
|
55
|
+
</picture>
|
|
56
|
+
</p>
|
|
57
|
+
|
|
58
|
+
<p align="center">
|
|
59
|
+
Turn any trained scikit-learn or XGBoost model into an LLM-callable tool —<br/>
|
|
60
|
+
auto-generated JSON schemas, typed I/O, zero boilerplate.
|
|
61
|
+
</p>
|
|
62
|
+
|
|
63
|
+
<p align="center">
|
|
64
|
+
<a href="https://pypi.org/project/predikit/"><img src="https://img.shields.io/pypi/v/predikit.svg" alt="PyPI version"/></a>
|
|
65
|
+
<a href="https://www.python.org/"><img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="Python 3.10+"/></a>
|
|
66
|
+
<a href="LICENSE"><img src="https://img.shields.io/badge/License-MIT-green.svg" alt="License: MIT"/></a>
|
|
67
|
+
<a href="https://github.com/astral-sh/ruff"><img src="https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json" alt="Ruff"/></a>
|
|
68
|
+
</p>
|
|
69
|
+
|
|
70
|
+
<p align="center">
|
|
71
|
+
<a href="https://pepy.tech/project/predikit"><img src="https://static.pepy.tech/personalized-badge/predikit?period=week&units=international_system&left_color=grey&right_color=blue&left_text=weekly+downloads" alt="Weekly Downloads"/></a>
|
|
72
|
+
<a href="https://pepy.tech/project/predikit"><img src="https://static.pepy.tech/personalized-badge/predikit?period=month&units=international_system&left_color=grey&right_color=blue&left_text=monthly+downloads" alt="Monthly Downloads"/></a>
|
|
73
|
+
<a href="https://pepy.tech/project/predikit"><img src="https://static.pepy.tech/personalized-badge/predikit?period=total&units=international_system&left_color=grey&right_color=blue&left_text=total+downloads" alt="Total Downloads"/></a>
|
|
74
|
+
</p>
|
|
57
75
|
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
76
|
+
```python
|
|
77
|
+
tool = ModelTool(model=clf, name="classify_iris", ...)
|
|
78
|
+
tool.to_openai() # OpenAI function schema, ready to pass to the API
|
|
79
|
+
tool.invoke({"sqft": 2200}) # → {"price_usd": 370730}
|
|
80
|
+
```
|
|
62
81
|
|
|
63
82
|
## Table of Contents
|
|
64
83
|
- [Install](#install)
|
|
65
84
|
- [30-second example](#30-second-example)
|
|
66
85
|
- [Core API](#core-api)
|
|
86
|
+
- [Field naming rule](#field-naming-rule)
|
|
67
87
|
- [Cookbook](#cookbook)
|
|
68
88
|
- [Contributing](#contributing)
|
|
69
89
|
- [License](#license)
|
|
70
90
|
|
|
71
|
-
## Turn any trained scikit-learn or XGBoost model into an LLM-callable tool — auto-generated JSON schemas, typed I/O, zero boilerplate.
|
|
72
|
-
|
|
73
|
-
```python
|
|
74
|
-
tool = ModelTool(model=clf, name="classify_iris", ...)
|
|
75
|
-
tool.to_openai() # OpenAI function schema, ready to pass to the API
|
|
76
|
-
tool.invoke({"sqft": 2200}) # → {"price_usd": 370730}
|
|
77
|
-
```
|
|
78
|
-
|
|
79
91
|
## Install
|
|
80
92
|
|
|
81
93
|
```bash
|
|
@@ -153,6 +165,7 @@ ModelTool(
|
|
|
153
165
|
| Method | Returns | What it does |
|
|
154
166
|
|--------|---------|--------------|
|
|
155
167
|
| `.invoke(input_dict)` | `dict` | Validates → predicts → returns `{output_name: value}` |
|
|
168
|
+
| `.ainvoke(input_dict)` | `dict` | Async version of `.invoke()` |
|
|
156
169
|
| `.to_openai()` | `dict` | OpenAI function-calling schema |
|
|
157
170
|
| `.to_langchain()` | `StructuredTool` | LangChain tool |
|
|
158
171
|
| `.to_callable()` | `Callable` | Plain Python function |
|
|
@@ -168,6 +181,30 @@ registry.to_langchain() # → list[StructuredTool]
|
|
|
168
181
|
registry.get("name") # → ModelTool
|
|
169
182
|
```
|
|
170
183
|
|
|
184
|
+
### `ModelEnsemble`
|
|
185
|
+
|
|
186
|
+
Call multiple models and reconcile their outputs in one step:
|
|
187
|
+
|
|
188
|
+
```python
|
|
189
|
+
ModelEnsemble(
|
|
190
|
+
tools: list[ModelTool], # models to run in parallel
|
|
191
|
+
name: str, # ensemble tool name the LLM sees
|
|
192
|
+
description: str,
|
|
193
|
+
strategy: str, # "collect" | "mean" | "vote" | "weighted_mean" | "weighted_vote"
|
|
194
|
+
weights: list[float], # optional, for weighted strategies
|
|
195
|
+
)
|
|
196
|
+
```
|
|
197
|
+
|
|
198
|
+
| Strategy | Behaviour |
|
|
199
|
+
|----------|-----------|
|
|
200
|
+
| `"collect"` | Merges all outputs into one dict (tools can have different `output_name`) |
|
|
201
|
+
| `"mean"` | Averages numeric outputs (all tools must share `output_name`) |
|
|
202
|
+
| `"vote"` | Majority class vote (all tools must share `output_name`) |
|
|
203
|
+
| `"weighted_mean"` | Weighted average — provide a `weights` list |
|
|
204
|
+
| `"weighted_vote"` | Weighted majority vote — provide a `weights` list |
|
|
205
|
+
|
|
206
|
+
`ModelEnsemble` exposes the same `.invoke()`, `.ainvoke()`, `.to_openai()`, and `.to_langchain()` interface as `ModelTool`.
|
|
207
|
+
|
|
171
208
|
## Field naming rule
|
|
172
209
|
|
|
173
210
|
**Your Pydantic schema field names must exactly match the column names the model was trained on.**
|
|
@@ -287,8 +324,6 @@ Only applies to classifiers that implement `predict_proba`. Regressors are unaff
|
|
|
287
324
|
|
|
288
325
|
### Multi-model ensemble
|
|
289
326
|
|
|
290
|
-
Call multiple models and reconcile their outputs in one step:
|
|
291
|
-
|
|
292
327
|
```python
|
|
293
328
|
from predikit import ModelEnsemble, ToolRegistry
|
|
294
329
|
|
|
@@ -303,12 +338,6 @@ result = ensemble.invoke(inputs) # → {"price_usd": 370112}
|
|
|
303
338
|
schema = ensemble.to_openai() # works exactly like ModelTool
|
|
304
339
|
```
|
|
305
340
|
|
|
306
|
-
| strategy | behaviour |
|
|
307
|
-
|----------|-----------|
|
|
308
|
-
| `"collect"` | merges all outputs into one dict (tools can have different `output_name`) |
|
|
309
|
-
| `"mean"` | averages numeric outputs (all tools must share `output_name`) |
|
|
310
|
-
| `"vote"` | majority class vote (all tools must share `output_name`) |
|
|
311
|
-
|
|
312
341
|
Register ensembles alongside individual tools:
|
|
313
342
|
|
|
314
343
|
```python
|
|
@@ -379,3 +408,4 @@ See [CONTRIBUTING.md](CONTRIBUTING.md) for development setup, code style, and PR
|
|
|
379
408
|
## License
|
|
380
409
|
|
|
381
410
|
MIT © Tejas Tumakuru Ashok
|
|
411
|
+
|
|
@@ -1,331 +1,361 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
-
|
|
17
|
-
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
#
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
pip install predikit
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
#
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
)
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
"
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
```python
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
```
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
)
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
)
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
```
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
1
|
+
<p align="center">
|
|
2
|
+
<picture>
|
|
3
|
+
<source srcset="https://raw.githubusercontent.com/Tejas-TA/predikit/main/docs/logo.gif">
|
|
4
|
+
<img src="https://raw.githubusercontent.com/Tejas-TA/predikit/main/docs/logo.png" alt="predikit" width="500"/>
|
|
5
|
+
</picture>
|
|
6
|
+
</p>
|
|
7
|
+
|
|
8
|
+
<p align="center">
|
|
9
|
+
Turn any trained scikit-learn or XGBoost model into an LLM-callable tool —<br/>
|
|
10
|
+
auto-generated JSON schemas, typed I/O, zero boilerplate.
|
|
11
|
+
</p>
|
|
12
|
+
|
|
13
|
+
<p align="center">
|
|
14
|
+
<a href="https://pypi.org/project/predikit/"><img src="https://img.shields.io/pypi/v/predikit.svg" alt="PyPI version"/></a>
|
|
15
|
+
<a href="https://www.python.org/"><img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="Python 3.10+"/></a>
|
|
16
|
+
<a href="LICENSE"><img src="https://img.shields.io/badge/License-MIT-green.svg" alt="License: MIT"/></a>
|
|
17
|
+
<a href="https://github.com/astral-sh/ruff"><img src="https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json" alt="Ruff"/></a>
|
|
18
|
+
</p>
|
|
19
|
+
|
|
20
|
+
<p align="center">
|
|
21
|
+
<a href="https://pepy.tech/project/predikit"><img src="https://static.pepy.tech/personalized-badge/predikit?period=week&units=international_system&left_color=grey&right_color=blue&left_text=weekly+downloads" alt="Weekly Downloads"/></a>
|
|
22
|
+
<a href="https://pepy.tech/project/predikit"><img src="https://static.pepy.tech/personalized-badge/predikit?period=month&units=international_system&left_color=grey&right_color=blue&left_text=monthly+downloads" alt="Monthly Downloads"/></a>
|
|
23
|
+
<a href="https://pepy.tech/project/predikit"><img src="https://static.pepy.tech/personalized-badge/predikit?period=total&units=international_system&left_color=grey&right_color=blue&left_text=total+downloads" alt="Total Downloads"/></a>
|
|
24
|
+
</p>
|
|
25
|
+
|
|
26
|
+
```python
|
|
27
|
+
tool = ModelTool(model=clf, name="classify_iris", ...)
|
|
28
|
+
tool.to_openai() # OpenAI function schema, ready to pass to the API
|
|
29
|
+
tool.invoke({"sqft": 2200}) # → {"price_usd": 370730}
|
|
30
|
+
```
|
|
31
|
+
|
|
32
|
+
## Table of Contents
|
|
33
|
+
- [Install](#install)
|
|
34
|
+
- [30-second example](#30-second-example)
|
|
35
|
+
- [Core API](#core-api)
|
|
36
|
+
- [Field naming rule](#field-naming-rule)
|
|
37
|
+
- [Cookbook](#cookbook)
|
|
38
|
+
- [Contributing](#contributing)
|
|
39
|
+
- [License](#license)
|
|
40
|
+
|
|
41
|
+
## Install
|
|
42
|
+
|
|
43
|
+
```bash
|
|
44
|
+
pip install predikit
|
|
45
|
+
|
|
46
|
+
# With XGBoost support
|
|
47
|
+
pip install predikit[xgboost]
|
|
48
|
+
|
|
49
|
+
# With LangChain support
|
|
50
|
+
pip install predikit[langchain]
|
|
51
|
+
|
|
52
|
+
# With MLflow Model Registry support
|
|
53
|
+
pip install predikit[mlflow]
|
|
54
|
+
|
|
55
|
+
# With Snowflake Model Registry support
|
|
56
|
+
pip install predikit[snowflake]
|
|
57
|
+
```
|
|
58
|
+
|
|
59
|
+
## 30-second example
|
|
60
|
+
|
|
61
|
+
```python
|
|
62
|
+
from pydantic import BaseModel, Field
|
|
63
|
+
from sklearn.datasets import load_iris
|
|
64
|
+
from sklearn.linear_model import LogisticRegression
|
|
65
|
+
from predikit import ModelTool
|
|
66
|
+
|
|
67
|
+
# Train
|
|
68
|
+
X, y = load_iris(return_X_y=True)
|
|
69
|
+
clf = LogisticRegression(max_iter=200).fit(X, y)
|
|
70
|
+
|
|
71
|
+
# Define what the LLM will pass in
|
|
72
|
+
class IrisInput(BaseModel):
|
|
73
|
+
sepal_length: float = Field(description="Sepal length in cm")
|
|
74
|
+
sepal_width: float = Field(description="Sepal width in cm")
|
|
75
|
+
petal_length: float = Field(description="Petal length in cm")
|
|
76
|
+
petal_width: float = Field(description="Petal width in cm")
|
|
77
|
+
|
|
78
|
+
# Wrap the model
|
|
79
|
+
tool = ModelTool(
|
|
80
|
+
model=clf,
|
|
81
|
+
name="classify_iris",
|
|
82
|
+
description="Classify an iris flower: 0=setosa, 1=versicolor, 2=virginica.",
|
|
83
|
+
input_schema=IrisInput,
|
|
84
|
+
output_name="species",
|
|
85
|
+
output_description="Predicted species index",
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
# Get an OpenAI-ready schema
|
|
89
|
+
import json
|
|
90
|
+
print(json.dumps(tool.to_openai(), indent=2))
|
|
91
|
+
|
|
92
|
+
# Call it directly
|
|
93
|
+
tool.invoke({
|
|
94
|
+
"sepal_length": 5.1, "sepal_width": 3.5,
|
|
95
|
+
"petal_length": 1.4, "petal_width": 0.2,
|
|
96
|
+
})
|
|
97
|
+
# → {"species": 0}
|
|
98
|
+
```
|
|
99
|
+
|
|
100
|
+
## Core API
|
|
101
|
+
|
|
102
|
+
### `ModelTool`
|
|
103
|
+
|
|
104
|
+
```python
|
|
105
|
+
ModelTool(
|
|
106
|
+
model, # fitted sklearn-compatible estimator
|
|
107
|
+
name: str, # tool name the LLM sees
|
|
108
|
+
description: str, # tool description the LLM sees
|
|
109
|
+
input_schema, # Pydantic BaseModel describing inputs
|
|
110
|
+
output_name: str, # key for the prediction in the returned dict
|
|
111
|
+
output_description: str,
|
|
112
|
+
)
|
|
113
|
+
```
|
|
114
|
+
|
|
115
|
+
| Method | Returns | What it does |
|
|
116
|
+
|--------|---------|--------------|
|
|
117
|
+
| `.invoke(input_dict)` | `dict` | Validates → predicts → returns `{output_name: value}` |
|
|
118
|
+
| `.ainvoke(input_dict)` | `dict` | Async version of `.invoke()` |
|
|
119
|
+
| `.to_openai()` | `dict` | OpenAI function-calling schema |
|
|
120
|
+
| `.to_langchain()` | `StructuredTool` | LangChain tool |
|
|
121
|
+
| `.to_callable()` | `Callable` | Plain Python function |
|
|
122
|
+
|
|
123
|
+
### `ToolRegistry`
|
|
124
|
+
|
|
125
|
+
Group multiple tools for bulk export:
|
|
126
|
+
|
|
127
|
+
```python
|
|
128
|
+
registry = ToolRegistry([price_tool, risk_tool])
|
|
129
|
+
registry.to_openai() # → list[dict], pass directly to OpenAI
|
|
130
|
+
registry.to_langchain() # → list[StructuredTool]
|
|
131
|
+
registry.get("name") # → ModelTool
|
|
132
|
+
```
|
|
133
|
+
|
|
134
|
+
### `ModelEnsemble`
|
|
135
|
+
|
|
136
|
+
Call multiple models and reconcile their outputs in one step:
|
|
137
|
+
|
|
138
|
+
```python
|
|
139
|
+
ModelEnsemble(
|
|
140
|
+
tools: list[ModelTool], # models to run in parallel
|
|
141
|
+
name: str, # ensemble tool name the LLM sees
|
|
142
|
+
description: str,
|
|
143
|
+
strategy: str, # "collect" | "mean" | "vote" | "weighted_mean" | "weighted_vote"
|
|
144
|
+
weights: list[float], # optional, for weighted strategies
|
|
145
|
+
)
|
|
146
|
+
```
|
|
147
|
+
|
|
148
|
+
| Strategy | Behaviour |
|
|
149
|
+
|----------|-----------|
|
|
150
|
+
| `"collect"` | Merges all outputs into one dict (tools can have different `output_name`) |
|
|
151
|
+
| `"mean"` | Averages numeric outputs (all tools must share `output_name`) |
|
|
152
|
+
| `"vote"` | Majority class vote (all tools must share `output_name`) |
|
|
153
|
+
| `"weighted_mean"` | Weighted average — provide a `weights` list |
|
|
154
|
+
| `"weighted_vote"` | Weighted majority vote — provide a `weights` list |
|
|
155
|
+
|
|
156
|
+
`ModelEnsemble` exposes the same `.invoke()`, `.ainvoke()`, `.to_openai()`, and `.to_langchain()` interface as `ModelTool`.
|
|
157
|
+
|
|
158
|
+
## Field naming rule
|
|
159
|
+
|
|
160
|
+
**Your Pydantic schema field names must exactly match the column names the model was trained on.**
|
|
161
|
+
|
|
162
|
+
predikit maps inputs to features by name, not position. If you trained on a DataFrame with columns `["sqft", "bedrooms"]`, your schema fields must be `sqft` and `bedrooms` — not `sq_ft`, not `Sqft`.
|
|
163
|
+
|
|
164
|
+
```python
|
|
165
|
+
# ✓ Columns match: sqft, bedrooms, bathrooms
|
|
166
|
+
class GoodInput(BaseModel):
|
|
167
|
+
sqft: float
|
|
168
|
+
bedrooms: float
|
|
169
|
+
bathrooms: float
|
|
170
|
+
|
|
171
|
+
# ✗ Name mismatch — raises ValueError at runtime
|
|
172
|
+
class BadInput(BaseModel):
|
|
173
|
+
square_footage: float # model expects "sqft"
|
|
174
|
+
beds: float # model expects "bedrooms"
|
|
175
|
+
baths: float # model expects "bathrooms"
|
|
176
|
+
```
|
|
177
|
+
|
|
178
|
+
When there's a mismatch, predikit tells you exactly which names are wrong:
|
|
179
|
+
|
|
180
|
+
```
|
|
181
|
+
ValueError: Input schema is missing model features: ['sqft', 'bedrooms', 'bathrooms'].
|
|
182
|
+
Schema has: ['square_footage', 'beds', 'baths'], model expects: ['sqft', 'bedrooms', 'bathrooms']
|
|
183
|
+
```
|
|
184
|
+
|
|
185
|
+
> **Tip:** If you trained with a numpy array (no DataFrame), predikit has no feature names to check — it uses your schema's field definition order instead.
|
|
186
|
+
|
|
187
|
+
## Cookbook
|
|
188
|
+
|
|
189
|
+
### XGBoost regression
|
|
190
|
+
|
|
191
|
+
```python
|
|
192
|
+
from xgboost import XGBRegressor
|
|
193
|
+
from predikit import ModelTool
|
|
194
|
+
|
|
195
|
+
reg = XGBRegressor().fit(X_train, y_train)
|
|
196
|
+
|
|
197
|
+
class HouseInput(BaseModel):
|
|
198
|
+
sqft: float
|
|
199
|
+
bedrooms: float
|
|
200
|
+
year_built: float
|
|
201
|
+
|
|
202
|
+
tool = ModelTool(
|
|
203
|
+
model=reg,
|
|
204
|
+
name="price_estimate",
|
|
205
|
+
description="Predict home price in USD.",
|
|
206
|
+
input_schema=HouseInput,
|
|
207
|
+
output_name="price_usd",
|
|
208
|
+
output_description="Predicted sale price in USD",
|
|
209
|
+
)
|
|
210
|
+
```
|
|
211
|
+
|
|
212
|
+
### Multiple tools in one registry
|
|
213
|
+
|
|
214
|
+
```python
|
|
215
|
+
registry = ToolRegistry([price_tool, risk_tool, demand_tool])
|
|
216
|
+
|
|
217
|
+
# OpenAI
|
|
218
|
+
response = client.chat.completions.create(
|
|
219
|
+
model="gpt-4o",
|
|
220
|
+
tools=registry.to_openai(),
|
|
221
|
+
...
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
# LangChain
|
|
225
|
+
agent = initialize_agent(tools=registry.to_langchain(), ...)
|
|
226
|
+
```
|
|
227
|
+
|
|
228
|
+
### Bool inputs from an LLM
|
|
229
|
+
|
|
230
|
+
LLMs sometimes return `"yes"`, `"true"`, or `"1"` for boolean fields. predikit coerces these automatically before Pydantic validation:
|
|
231
|
+
|
|
232
|
+
```python
|
|
233
|
+
class Input(BaseModel):
|
|
234
|
+
has_pool: bool
|
|
235
|
+
|
|
236
|
+
tool.invoke({"has_pool": "yes"}) # → coerced to True
|
|
237
|
+
tool.invoke({"has_pool": "false"}) # → coerced to False
|
|
238
|
+
tool.invoke({"has_pool": "maybe"}) # → raises ValueError with clear message
|
|
239
|
+
```
|
|
240
|
+
|
|
241
|
+
Supported strings: `true/false`, `yes/no`, `1/0`, `on/off`.
|
|
242
|
+
|
|
243
|
+
### Confidence-aware routing
|
|
244
|
+
|
|
245
|
+
Route uncertain predictions to a fallback tool, or raise an error the agent can catch:
|
|
246
|
+
|
|
247
|
+
```python
|
|
248
|
+
from predikit import ModelTool, LowConfidenceError
|
|
249
|
+
|
|
250
|
+
tool = ModelTool(
|
|
251
|
+
model=clf,
|
|
252
|
+
name="churn_risk",
|
|
253
|
+
description="Predict member churn risk.",
|
|
254
|
+
input_schema=MemberInput,
|
|
255
|
+
output_name="churn_probability",
|
|
256
|
+
output_description="Probability of churn (0–1)",
|
|
257
|
+
confidence_threshold=0.80, # classifiers with predict_proba only
|
|
258
|
+
on_low_confidence="warn", # "warn" | "raise" | "fallback"
|
|
259
|
+
fallback_tool=rule_based_tool, # used when mode="fallback"
|
|
260
|
+
)
|
|
261
|
+
|
|
262
|
+
result = tool.invoke(inputs)
|
|
263
|
+
if result.get("_low_confidence"):
|
|
264
|
+
print(f"Uncertain ({result['_confidence']:.2f}) — consider routing to a human")
|
|
265
|
+
```
|
|
266
|
+
|
|
267
|
+
| mode | behaviour |
|
|
268
|
+
|------|-----------|
|
|
269
|
+
| `"warn"` | returns prediction + `_confidence` + `_low_confidence: True` |
|
|
270
|
+
| `"raise"` | raises `LowConfidenceError` |
|
|
271
|
+
| `"fallback"` | invokes `fallback_tool` and returns its result |
|
|
272
|
+
|
|
273
|
+
Only applies to classifiers that implement `predict_proba`. Regressors are unaffected.
|
|
274
|
+
|
|
275
|
+
### Multi-model ensemble
|
|
276
|
+
|
|
277
|
+
```python
|
|
278
|
+
from predikit import ModelEnsemble, ToolRegistry
|
|
279
|
+
|
|
280
|
+
ensemble = ModelEnsemble(
|
|
281
|
+
tools=[price_tool_a, price_tool_b],
|
|
282
|
+
name="averaged_price",
|
|
283
|
+
description="Ensemble price: mean of two XGBoost models.",
|
|
284
|
+
strategy="mean", # "collect" | "mean" | "vote" | "weighted_mean" | "weighted_vote"
|
|
285
|
+
)
|
|
286
|
+
|
|
287
|
+
result = ensemble.invoke(inputs) # → {"price_usd": 370112}
|
|
288
|
+
schema = ensemble.to_openai() # works exactly like ModelTool
|
|
289
|
+
```
|
|
290
|
+
|
|
291
|
+
Register ensembles alongside individual tools:
|
|
292
|
+
|
|
293
|
+
```python
|
|
294
|
+
registry = ToolRegistry(tools=[price_tool], ensembles=[ensemble])
|
|
295
|
+
registry.to_openai() # includes both tools and ensembles
|
|
296
|
+
```
|
|
297
|
+
|
|
298
|
+
### MLflow Model Registry loader
|
|
299
|
+
|
|
300
|
+
Load a registered MLflow model directly — no manual `.load_model()` call:
|
|
301
|
+
|
|
302
|
+
```python
|
|
303
|
+
from predikit.loaders import from_mlflow
|
|
304
|
+
|
|
305
|
+
tool = from_mlflow(
|
|
306
|
+
model_uri="models:/churn-classifier/Production",
|
|
307
|
+
name="churn_risk",
|
|
308
|
+
description="Predict member churn probability.",
|
|
309
|
+
input_schema=MemberInput,
|
|
310
|
+
output_name="churn_probability",
|
|
311
|
+
output_description="Churn probability 0–1",
|
|
312
|
+
)
|
|
313
|
+
|
|
314
|
+
tool.invoke({"tenure_months": 24, "trips_last_year": 2, "avg_spend": 500})
|
|
315
|
+
# → {"churn_probability": 0.73}
|
|
316
|
+
```
|
|
317
|
+
|
|
318
|
+
The loader auto-detects `classes_` and `feature_names_in_` from the underlying sklearn model, so confidence routing and ensemble work unchanged. Requires `pip install predikit[mlflow]`.
|
|
319
|
+
|
|
320
|
+
### Snowflake Model Registry loader
|
|
321
|
+
|
|
322
|
+
Load a model registered in the Snowflake Model Registry via the Snowpark ML Python library:
|
|
323
|
+
|
|
324
|
+
```python
|
|
325
|
+
from predikit.loaders import from_snowflake
|
|
326
|
+
|
|
327
|
+
tool = from_snowflake(
|
|
328
|
+
session=snowpark_session,
|
|
329
|
+
model_name="VACATION_CHURN",
|
|
330
|
+
model_version="V3",
|
|
331
|
+
name="churn_risk",
|
|
332
|
+
description="Churn classifier.",
|
|
333
|
+
input_schema=MemberInput,
|
|
334
|
+
output_name="churn_probability",
|
|
335
|
+
output_description="Churn probability 0–1",
|
|
336
|
+
output_method="predict", # method to call on the Snowflake model object
|
|
337
|
+
)
|
|
338
|
+
```
|
|
339
|
+
|
|
340
|
+
Pass `output_method="predict_proba"` or any other method your Snowflake model exposes. The returned `ModelTool` is identical to one built directly — all exporters, confidence routing, and ensemble strategies work as-is. Requires `pip install predikit[snowflake]`.
|
|
341
|
+
|
|
342
|
+
### Orlando real estate demo
|
|
343
|
+
|
|
344
|
+
See [`examples/03_orlando_real_estate.py`](examples/03_orlando_real_estate.py) for a full end-to-end walkthrough: synthetic dataset → XGBoost training → `ModelTool` → registry → OpenAI schema → prediction.
|
|
345
|
+
|
|
346
|
+
## Roadmap
|
|
347
|
+
|
|
348
|
+
Planned for later releases:
|
|
349
|
+
|
|
350
|
+
- HuggingFace / PyTorch / TensorFlow model support
|
|
351
|
+
- Streaming inference support
|
|
352
|
+
- OpenAI Assistants API integration
|
|
353
|
+
|
|
354
|
+
## Contributing
|
|
355
|
+
|
|
356
|
+
See [CONTRIBUTING.md](CONTRIBUTING.md) for development setup, code style, and PR guidelines. The [CHANGELOG](CHANGELOG.md) tracks notable changes per release.
|
|
357
|
+
|
|
358
|
+
## License
|
|
359
|
+
|
|
360
|
+
MIT © Tejas Tumakuru Ashok
|
|
361
|
+
|
|
Binary file
|
|
Binary file
|
|
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "predikit"
|
|
7
|
-
version = "0.4.1"
|
|
7
|
+
version = "0.4.2"
|
|
8
8
|
description = "Turn any trained sklearn/XGBoost model into an LLM-callable tool with auto-generated schemas and typed I/O."
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
license = {text = "MIT"}
|
|
@@ -14,7 +14,7 @@ authors = [
|
|
|
14
14
|
]
|
|
15
15
|
keywords = ["llm", "agents", "sklearn", "xgboost", "function-calling", "ml-tools"]
|
|
16
16
|
classifiers = [
|
|
17
|
-
"Development Status :: 3 - Alpha",
|
|
17
|
+
"Development Status :: 4 - Beta",
|
|
18
18
|
"License :: OSI Approved :: MIT License",
|
|
19
19
|
"Programming Language :: Python :: 3",
|
|
20
20
|
"Programming Language :: Python :: 3.10",
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|