weco 0.1.10__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {weco-0.1.10 → weco-0.2.0}/.github/workflows/lint.yml +11 -21
- {weco-0.1.10 → weco-0.2.0}/.github/workflows/release.yml +23 -11
- weco-0.2.0/.gitignore +71 -0
- {weco-0.1.10 → weco-0.2.0}/LICENSE +2 -1
- weco-0.2.0/PKG-INFO +129 -0
- weco-0.2.0/README.md +107 -0
- weco-0.2.0/examples/simple-mlx/evaluate.py +133 -0
- weco-0.2.0/examples/simple-mlx/metal-examples.rst +427 -0
- weco-0.2.0/examples/simple-mlx/optimize.py +26 -0
- weco-0.2.0/examples/simple-torch/evaluate.py +125 -0
- weco-0.2.0/examples/simple-torch/optimize.py +26 -0
- weco-0.2.0/pyproject.toml +55 -0
- weco-0.2.0/weco/__init__.py +4 -0
- weco-0.2.0/weco/api.py +89 -0
- weco-0.2.0/weco/cli.py +333 -0
- weco-0.2.0/weco/panels.py +359 -0
- weco-0.2.0/weco/utils.py +119 -0
- weco-0.2.0/weco.egg-info/PKG-INFO +129 -0
- weco-0.2.0/weco.egg-info/SOURCES.txt +22 -0
- weco-0.2.0/weco.egg-info/entry_points.txt +2 -0
- weco-0.2.0/weco.egg-info/requires.txt +7 -0
- weco-0.1.10/.gitignore +0 -171
- weco-0.1.10/PKG-INFO +0 -125
- weco-0.1.10/README.md +0 -98
- weco-0.1.10/assets/weco.svg +0 -91
- weco-0.1.10/examples/cookbook.ipynb +0 -411
- weco-0.1.10/pyproject.toml +0 -71
- weco-0.1.10/tests/test_asynchronous.py +0 -93
- weco-0.1.10/tests/test_batching.py +0 -82
- weco-0.1.10/tests/test_reasoning.py +0 -59
- weco-0.1.10/tests/test_synchronous.py +0 -89
- weco-0.1.10/weco/__init__.py +0 -4
- weco-0.1.10/weco/client.py +0 -586
- weco-0.1.10/weco/constants.py +0 -4
- weco-0.1.10/weco/functional.py +0 -184
- weco-0.1.10/weco/utils.py +0 -180
- weco-0.1.10/weco.egg-info/PKG-INFO +0 -125
- weco-0.1.10/weco.egg-info/SOURCES.txt +0 -22
- weco-0.1.10/weco.egg-info/requires.txt +0 -13
- {weco-0.1.10 → weco-0.2.0}/setup.cfg +0 -0
- {weco-0.1.10 → weco-0.2.0}/weco.egg-info/dependency_links.txt +0 -0
- {weco-0.1.10 → weco-0.2.0}/weco.egg-info/top_level.txt +0 -0
{weco-0.1.10 → weco-0.2.0}/.github/workflows/lint.yml
CHANGED

@@ -6,12 +6,8 @@ on:
       - main
       - dev
 
-  pull_request:
-    branches:
-      - main
-
 jobs:
-
+  lint:
     runs-on: ubuntu-latest
 
     steps:
@@ -23,26 +19,20 @@ jobs:
       - name: Set up Python
        uses: actions/setup-python@v3
        with:
-          python-version: "3.
+          python-version: "3.12.0"
 
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
-          pip install
+          pip install ruff
 
-      - name:
+      - name: Run linter
        run: |
-
-
-
-
-
-
-      - name: Run black
-        run: black .
-
-      - name: Run isort
-        run: isort .
+          ruff check . --fix
+
+      - name: Run formatter
+        run: |
+          ruff format .
 
      - name: Commit changes
        run: |
@@ -52,6 +42,6 @@ jobs:
          if git diff --exit-code --staged; then
            echo "No changes to commit"
          else
-            git commit -m "[
+            git commit -m "[GitHub Action] Lint and format code with Ruff"
            git push https://${{ secrets.GITHUB_TOKEN }}@github.com/${{ github.repository }}
-          fi
+          fi
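Net effect of this change: the lint workflow stops triggering on `pull_request`, pins Python to 3.12.0, and swaps the separate Black and isort steps for Ruff, running `ruff check . --fix` followed by `ruff format .` and committing whatever changes result. The same two commands reproduce the CI pass locally after `pip install ruff`. (The content of several removed lines was lost in extraction and is shown as bare `-` markers above.)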
{weco-0.1.10 → weco-0.2.0}/.github/workflows/release.yml
CHANGED

@@ -1,9 +1,10 @@
 name: Publish Python Package
 
 on:
-
-
-
+  workflow_run:
+    workflows: ["Lint and Format Code"]
+    types:
+      - completed
 
   release:
     types: [published]
@@ -14,18 +15,25 @@ jobs:
     runs-on: ubuntu-latest
 
     steps:
-      -
+      - name: Checkout code
+        uses: actions/checkout@v4
+        with:
+          ref: ${{ github.head_ref }}
+
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
-          python-version: "3.
+          python-version: "3.12"
+
      - name: Install build dependencies
        run: >-
          python3 -m pip install build --user
+
      - name: Build a source distribution and a wheel
        run: python3 -m build
+
      - name: Store the distribution packages
-        uses: actions/upload-artifact@
+        uses: actions/upload-artifact@v4
        with:
          name: python-package-distributions
          path: dist/
@@ -44,10 +52,11 @@ jobs:
 
     steps:
      - name: Download the dists
-        uses: actions/download-artifact@
+        uses: actions/download-artifact@v4
        with:
          name: python-package-distributions
          path: dist/
+
      - name: Publish to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
 
@@ -64,24 +73,27 @@ jobs:
 
     steps:
      - name: Download dists
-        uses: actions/download-artifact@
+        uses: actions/download-artifact@v4
        with:
          name: python-package-distributions
          path: dist/
+
      - name: Sign dists with Sigstore
        uses: sigstore/gh-action-sigstore-python@v2.1.1
        with:
          inputs: >-
            ./dist/*.tar.gz
            ./dist/*.whl
+
      - name: Create GitHub Release
        env:
          GITHUB_TOKEN: ${{ github.token }}
        run: >-
          gh release create
-          'v0.
+          'v0.2.0'
          --repo '${{ github.repository }}'
          --notes ""
+
      - name: Upload artifact signatures to GitHub Release
        env:
          GITHUB_TOKEN: ${{ github.token }}
@@ -90,5 +102,5 @@ jobs:
        # sigstore-produced signatures and certificates.
        run: >-
          gh release upload
-          'v0.
-          --repo '${{ github.repository }}'
+          'v0.2.0' dist/**
+          --repo '${{ github.repository }}'
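Net effect: the release workflow now also runs after the "Lint and Format Code" workflow completes (via a `workflow_run` trigger), pins `actions/checkout`, `actions/upload-artifact`, and `actions/download-artifact` to v4, pins Python to 3.12, and hard-codes the release tag `v0.2.0`. One caveat worth noting: `workflow_run` with `types: [completed]` fires whether the triggering workflow succeeded or failed, and nothing in the hunks shown guards on `github.event.workflow_run.conclusion == 'success'`.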
weco-0.2.0/.gitignore
ADDED
@@ -0,0 +1,71 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+.pytest_cache/
+.coverage
+htmlcov/
+.env
+.venv
+venv/
+ENV/
+
+# VSCode Extension
+node_modules/
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+.npm
+*.tsbuildinfo
+.eslintcache
+.next
+out/
+build/
+dist/
+
+# IDEs and editors
+.idea/
+.vscode/
+*.swp
+*.swo
+.project
+.classpath
+.settings/
+*.sublime-workspace
+
+# OS
+.DS_Store
+.DS_Store?
+._*
+.Spotlight-V100
+.Trashes
+ehthumbs.db
+Thumbs.db
+
+# Linting
+.ruff_cache/
+
+# Miscellaneous
+etc/
+
+# AI generated files
+digest.txt
+.runs/
weco-0.2.0/PKG-INFO
ADDED
@@ -0,0 +1,129 @@
+Metadata-Version: 2.4
+Name: weco
+Version: 0.2.0
+Summary: Documentation for `weco`, a CLI for using Weco AI's code optimizer.
+Author-email: Weco AI Team <dhruv@weco.ai>
+License: MIT
+Project-URL: Homepage, https://github.com/WecoAI/weco-cli
+Keywords: AI,Code Optimization,Code Generation
+Classifier: Programming Language :: Python :: 3
+Classifier: Operating System :: OS Independent
+Classifier: License :: OSI Approved :: MIT License
+Requires-Python: >=3.12
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: requests
+Requires-Dist: rich
+Provides-Extra: dev
+Requires-Dist: ruff; extra == "dev"
+Requires-Dist: build; extra == "dev"
+Requires-Dist: setuptools_scm; extra == "dev"
+Dynamic: license-file
+
+# Weco CLI – Optimize Your Code Effortlessly
+
+[](https://www.python.org)
+[](LICENSE)
+
+`weco` is a powerful command-line interface for interacting with Weco AI's code optimizer. Whether you are looking to improve performance or refine code quality, our CLI streamlines your workflow for a better development experience.
+
+---
+
+## Overview
+
+The `weco` CLI leverages advanced optimization techniques and language model strategies to iteratively improve your source code. It supports multiple language models and offers a flexible configuration to suit different optimization tasks.
+
+---
+
+## Setup
+
+1. **Install the Package:**
+
+   ```bash
+   pip install weco
+   ```
+
+2. **Configure API Keys:**
+
+   Set the appropriate environment variables for your language model provider:
+
+   - **OpenAI:** `export OPENAI_API_KEY="your_key_here"`
+   - **Anthropic:** `export ANTHROPIC_API_KEY="your_key_here"`
+
+---
+
+## Usage
+
+### Command Line Arguments
+
+| Argument                    | Description                                                                                              | Required |
+|-----------------------------|----------------------------------------------------------------------------------------------------------|----------|
+| `--source`                  | Path to the Python source code that will be optimized (e.g. optimize.py).                                 | Yes      |
+| `--eval-command`            | Command to run for evaluation (e.g. 'python eval.py --arg1=val1').                                        | Yes      |
+| `--metric`                  | Metric to optimize.                                                                                       | Yes      |
+| `--maximize`                | Boolean flag indicating whether to maximize the metric.                                                   | Yes      |
+| `--steps`                   | Number of optimization steps to run.                                                                      | Yes      |
+| `--model`                   | Model to use for optimization.                                                                            | Yes      |
+| `--additional-instructions` | (Optional) Description of additional instructions or path to a file containing additional instructions.   | No       |
+
+---
+
+### Example
+
+Optimizing common operations in PyTorch:
+```bash
+weco --source examples/simple-torch/optimize.py \
+     --eval-command "python examples/simple-torch/evaluate.py --solution-path examples/simple-torch/optimize.py --device mps" \
+     --metric "speedup" \
+     --maximize true \
+     --steps 15 \
+     --model "o3-mini" \
+     --additional-instructions "Fuse operations in the forward method while ensuring the max float deviation remains small. Maintain the same format of the code."
+```
+
+Optimizing the same operations using MLX and Metal:
+```bash
+weco --source examples/simple-mlx/optimize.py \
+     --eval-command "python examples/simple-mlx/evaluate.py --solution-path examples/simple-mlx/optimize.py" \
+     --metric "speedup" \
+     --maximize true \
+     --steps 30 \
+     --model "o3-mini" \
+     --additional-instructions "examples/simple-mlx/metal-examples.rst"
+```
+---
+
+## Supported Providers
+
+The CLI supports the following model providers:
+
+- **OpenAI:** Set your API key using `OPENAI_API_KEY`.
+- **Anthropic:** Set your API key using `ANTHROPIC_API_KEY`.
+
+---
+
+## Contributing
+
+We welcome contributions! To get started:
+
+1. **Fork and Clone the Repository:**
+   ```bash
+   git clone https://github.com/WecoAI/weco-cli.git
+   cd weco-cli
+   ```
+
+2. **Install Development Dependencies:**
+   ```bash
+   pip install -e ".[dev]"
+   ```
+
+3. **Create a Feature Branch:**
+   ```bash
+   git checkout -b feature/your-feature-name
+   ```
+
+4. **Make Your Changes:** Ensure your code adheres to our style guidelines and includes relevant tests.
+
+5. **Commit and Push** your changes, then open a pull request with a clear description of your enhancements.
+
+---
weco-0.2.0/README.md
ADDED
@@ -0,0 +1,107 @@
+# Weco CLI – Optimize Your Code Effortlessly
+
+[](https://www.python.org)
+[](LICENSE)
+
+`weco` is a powerful command-line interface for interacting with Weco AI's code optimizer. Whether you are looking to improve performance or refine code quality, our CLI streamlines your workflow for a better development experience.
+
+---
+
+## Overview
+
+The `weco` CLI leverages advanced optimization techniques and language model strategies to iteratively improve your source code. It supports multiple language models and offers a flexible configuration to suit different optimization tasks.
+
+---
+
+## Setup
+
+1. **Install the Package:**
+
+   ```bash
+   pip install weco
+   ```
+
+2. **Configure API Keys:**
+
+   Set the appropriate environment variables for your language model provider:
+
+   - **OpenAI:** `export OPENAI_API_KEY="your_key_here"`
+   - **Anthropic:** `export ANTHROPIC_API_KEY="your_key_here"`
+
+---
+
+## Usage
+
+### Command Line Arguments
+
+| Argument                    | Description                                                                                              | Required |
+|-----------------------------|----------------------------------------------------------------------------------------------------------|----------|
+| `--source`                  | Path to the Python source code that will be optimized (e.g. optimize.py).                                 | Yes      |
+| `--eval-command`            | Command to run for evaluation (e.g. 'python eval.py --arg1=val1').                                        | Yes      |
+| `--metric`                  | Metric to optimize.                                                                                       | Yes      |
+| `--maximize`                | Boolean flag indicating whether to maximize the metric.                                                   | Yes      |
+| `--steps`                   | Number of optimization steps to run.                                                                      | Yes      |
+| `--model`                   | Model to use for optimization.                                                                            | Yes      |
+| `--additional-instructions` | (Optional) Description of additional instructions or path to a file containing additional instructions.   | No       |
+
+---
+
+### Example
+
+Optimizing common operations in PyTorch:
+```bash
+weco --source examples/simple-torch/optimize.py \
+     --eval-command "python examples/simple-torch/evaluate.py --solution-path examples/simple-torch/optimize.py --device mps" \
+     --metric "speedup" \
+     --maximize true \
+     --steps 15 \
+     --model "o3-mini" \
+     --additional-instructions "Fuse operations in the forward method while ensuring the max float deviation remains small. Maintain the same format of the code."
+```
+
+Optimizing the same operations using MLX and Metal:
+```bash
+weco --source examples/simple-mlx/optimize.py \
+     --eval-command "python examples/simple-mlx/evaluate.py --solution-path examples/simple-mlx/optimize.py" \
+     --metric "speedup" \
+     --maximize true \
+     --steps 30 \
+     --model "o3-mini" \
+     --additional-instructions "examples/simple-mlx/metal-examples.rst"
+```
+---
+
+## Supported Providers
+
+The CLI supports the following model providers:
+
+- **OpenAI:** Set your API key using `OPENAI_API_KEY`.
+- **Anthropic:** Set your API key using `ANTHROPIC_API_KEY`.
+
+---
+
+## Contributing
+
+We welcome contributions! To get started:
+
+1. **Fork and Clone the Repository:**
+   ```bash
+   git clone https://github.com/WecoAI/weco-cli.git
+   cd weco-cli
+   ```
+
+2. **Install Development Dependencies:**
+   ```bash
+   pip install -e ".[dev]"
+   ```
+
+3. **Create a Feature Branch:**
+   ```bash
+   git checkout -b feature/your-feature-name
+   ```
+
+4. **Make Your Changes:** Ensure your code adheres to our style guidelines and includes relevant tests.
+
+5. **Commit and Push** your changes, then open a pull request with a clear description of your enhancements.
+
+---
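The README's examples pass `--metric "speedup"` and an `--eval-command` whose output contains that metric; the bundled evaluator below ends by printing `speedup: ...`. As a rough sketch of the contract an evaluation script appears to satisfy (the exact parsing rule is an assumption here, not documented in this diff: we assume weco looks for `<metric>: <number>` in the command's stdout), a custom evaluator only needs to exercise the candidate code and print the chosen metric:

```python
# Hypothetical minimal evaluation script for `weco --metric "accuracy"`.
# Assumption: weco extracts the metric by matching "<metric>: <number>" in
# this script's stdout, as the bundled examples' final print suggests.
import argparse
import importlib.util


def load_candidate(path):
    # Import the file weco is currently rewriting (passed via --solution-path).
    spec = importlib.util.spec_from_file_location("candidate", path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--solution-path", type=str, required=True)
    args = parser.parse_args()

    candidate = load_candidate(args.solution_path)
    # `score()` is an illustrative entry point; use whatever your task defines.
    accuracy = candidate.score()
    print(f"accuracy: {accuracy:.4f}")
```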
weco-0.2.0/examples/simple-mlx/evaluate.py
ADDED

@@ -0,0 +1,133 @@
+import time
+import sys
+import pathlib
+import importlib
+import traceback
+import mlx.core as mx
+import mlx.nn as nn
+
+
+########################################################
+# Baseline
+########################################################
+class Model(nn.Module):
+    """
+    Model that performs a matrix multiplication, division, summation, and scaling.
+    """
+
+    def __init__(self, input_size, hidden_size, scaling_factor):
+        super(Model, self).__init__()
+        self.weight = mx.random.normal(shape=(hidden_size, input_size))
+        self.scaling_factor = scaling_factor
+
+    def __call__(self, x):
+        """
+        Args:
+            x (mx.array): Input tensor of shape (batch_size, input_size).
+        Returns:
+            mx.array: Output tensor of shape (batch_size, hidden_size).
+        """
+        x = mx.matmul(x, mx.transpose(self.weight))  # Gemm
+        x = x / 2  # Divide
+        x = mx.sum(x, axis=1, keepdims=True)  # Sum
+        x = x * self.scaling_factor  # Scaling
+        return x
+
+
+########################################################
+# Weco Solution
+########################################################
+def load_module_from_path(module_path: str, add_to_sys_modules: bool = False):
+    # Clean out all old compiled extensions to prevent namespace collisions during build
+    module_path = pathlib.Path(module_path)
+    name = module_path.stem
+    spec = importlib.util.spec_from_file_location(name, module_path)
+    mod = importlib.util.module_from_spec(spec)  # type: ignore
+    if add_to_sys_modules:
+        sys.modules[name] = mod
+    spec.loader.exec_module(mod)  # type: ignore
+    return mod
+
+
+########################################################
+# Benchmark
+########################################################
+def get_inputs(B, N):
+    # MLX doesn't use device parameter like PyTorch, as it automatically uses Metal
+    return mx.random.normal(shape=(B, N), dtype=mx.float32)
+
+
+def bench(f, inputs, n_warmup, n_rep):
+    # Warm up
+    for _ in range(n_warmup):
+        result = f(inputs)
+        mx.eval(result)  # Force computation due to lazy evaluation
+
+    t_avg = 0.0
+    for _ in range(n_rep):
+        # Clear cache before timing
+        mx.metal.clear_cache()
+
+        start_time = time.time()
+        result = f(inputs)
+        mx.eval(result)  # Force computation
+        mx.synchronize()  # Wait for all computations to complete
+        t_avg += time.time() - start_time
+
+    t_avg /= n_rep * 1e-3
+    return t_avg
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--solution-path", type=str, required=True)
+    args = parser.parse_args()
+
+    # init and input parameters
+    B, N, H, S = 128, 10, 20, 1.5
+
+    # Set the default device to 0
+    mx.set_default_device(mx.gpu)
+
+    # load solution module
+    try:
+        mx.random.seed(0)
+        solution_module = load_module_from_path(args.solution_path, add_to_sys_modules=False)
+        solution_model = solution_module.Model(N, H, S)
+        assert hasattr(solution_model, "__call__")
+    except Exception:
+        print(f"Candidate module initialization failed: {traceback.format_exc()}")
+        exit(1)
+
+    mx.random.seed(0)
+    baseline_model = Model(N, H, S)
+
+    # measure correctness
+    n_correctness_trials = 10
+    max_diff_avg = 0
+    for _ in range(n_correctness_trials):
+        inputs = get_inputs(B, N)
+        baseline_output = baseline_model(inputs)
+        optimized_output = solution_model(inputs)
+        max_diff = mx.max(mx.abs(optimized_output - baseline_output))
+        mx.eval(max_diff)  # Force computation
+        max_diff_avg += max_diff.item()  # Convert to Python scalar
+    max_diff_avg /= n_correctness_trials
+    print(f"max float diff between values of baseline and optimized model: {max_diff_avg}")
+
+    # measure performance
+    inputs = get_inputs(B, N)
+    n_warmup = 100
+    n_rep = 500
+
+    # baseline
+    t_avg_baseline = bench(baseline_model, inputs, n_warmup, n_rep)
+    print(f"baseline time: {t_avg_baseline:.2f}ms")
+
+    # optimized
+    t_avg_optimized = bench(solution_model, inputs, n_warmup, n_rep)
+    print(f"optimized time: {t_avg_optimized:.2f}ms")
+
+    print(f"speedup: {t_avg_baseline / t_avg_optimized:.2f}x")