inference-perf 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inference_perf/__init__.py +16 -0
- inference_perf/analysis/__init__.py +17 -0
- inference_perf/analysis/analyze.py +276 -0
- inference_perf/apis/__init__.py +26 -0
- inference_perf/apis/base.py +60 -0
- inference_perf/apis/chat.py +64 -0
- inference_perf/apis/completion.py +78 -0
- inference_perf/client/filestorage/__init__.py +20 -0
- inference_perf/client/filestorage/base.py +31 -0
- inference_perf/client/filestorage/gcs.py +55 -0
- inference_perf/client/filestorage/local.py +38 -0
- inference_perf/client/filestorage/s3.py +59 -0
- inference_perf/client/metricsclient/__init__.py +17 -0
- inference_perf/client/metricsclient/base.py +72 -0
- inference_perf/client/metricsclient/mock_client.py +28 -0
- inference_perf/client/metricsclient/prometheus_client/__init__.py +18 -0
- inference_perf/client/metricsclient/prometheus_client/base.py +284 -0
- inference_perf/client/metricsclient/prometheus_client/google_managed_prometheus_client.py +43 -0
- inference_perf/client/modelserver/__init__.py +19 -0
- inference_perf/client/modelserver/base.py +78 -0
- inference_perf/client/modelserver/mock_client.py +52 -0
- inference_perf/client/modelserver/vllm_client.py +245 -0
- inference_perf/client/requestdatacollector/__init__.py +23 -0
- inference_perf/client/requestdatacollector/base.py +31 -0
- inference_perf/client/requestdatacollector/local.py +30 -0
- inference_perf/client/requestdatacollector/multiprocess.py +56 -0
- inference_perf/config.py +193 -0
- inference_perf/datagen/__init__.py +28 -0
- inference_perf/datagen/base.py +66 -0
- inference_perf/datagen/hf_sharegpt_datagen.py +94 -0
- inference_perf/datagen/mock_datagen.py +43 -0
- inference_perf/datagen/random_datagen.py +105 -0
- inference_perf/datagen/shared_prefix_datagen.py +99 -0
- inference_perf/datagen/synthetic_datagen.py +591 -0
- inference_perf/loadgen/__init__.py +16 -0
- inference_perf/loadgen/load_generator.py +212 -0
- inference_perf/loadgen/load_timer.py +78 -0
- inference_perf/logger.py +35 -0
- inference_perf/main.py +244 -0
- inference_perf/reportgen/__init__.py +16 -0
- inference_perf/reportgen/base.py +272 -0
- inference_perf/utils/__init__.py +17 -0
- inference_perf/utils/custom_tokenizer.py +30 -0
- inference_perf/utils/distribution.py +60 -0
- inference_perf/utils/report_file.py +30 -0
- inference_perf-0.1.1.dist-info/METADATA +232 -0
- inference_perf-0.1.1.dist-info/RECORD +51 -0
- inference_perf-0.1.1.dist-info/WHEEL +5 -0
- inference_perf-0.1.1.dist-info/entry_points.txt +2 -0
- inference_perf-0.1.1.dist-info/licenses/LICENSE +201 -0
- inference_perf-0.1.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# Copyright 2025 The Kubernetes Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
from .main import main_cli
|
|
15
|
+
|
|
16
|
+
__all__ = ["main_cli"]
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 The Kubernetes Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .analyze import analyze_reports
|
|
16
|
+
|
|
17
|
+
__all__ = ["analyze_reports"]
|
|
@@ -0,0 +1,276 @@
|
|
|
1
|
+
# Copyright 2025 The Kubernetes Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import json
|
|
16
|
+
import logging
|
|
17
|
+
import operator
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from typing import Any, Dict, List, Tuple
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _extract_latency_metric(latency_data: Dict[str, Any], metric_name: str, convert_to_ms: bool = False) -> float | None:
|
|
25
|
+
"""Helper to extract a metric's mean value from latency data."""
|
|
26
|
+
metric_data = latency_data.get(metric_name)
|
|
27
|
+
if isinstance(metric_data, dict):
|
|
28
|
+
mean_val = metric_data.get("mean")
|
|
29
|
+
if isinstance(mean_val, (int, float)):
|
|
30
|
+
return mean_val * 1000 if convert_to_ms else mean_val
|
|
31
|
+
return None
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _extract_throughput_metric(throughput_data: Dict[str, Any], metric_name: str) -> float | None:
|
|
35
|
+
"""Helper to extract a throughput metric's value."""
|
|
36
|
+
metric_value = throughput_data.get(metric_name)
|
|
37
|
+
if isinstance(metric_value, (int, float)):
|
|
38
|
+
return float(metric_value)
|
|
39
|
+
return None
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _generate_plot(charts_to_generate: List[Dict[str, Any]], suptitle: str, output_path: Path) -> None:
    """Render one row of line charts and save the figure to ``output_path``.

    Args:
        charts_to_generate: One dict per subplot. Each dict must provide "title", "ylabel"
            and "data" (a sequence of (x, y) pairs); "xlabel" is optional and defaults to
            "QPS (requested rate)". When the list is empty a warning is logged and nothing
            is written.
        suptitle: Title drawn above all subplots.
        output_path: Destination file for the rendered figure.
    """
    if not charts_to_generate:
        logger.warning(f"No data available to generate chart: {output_path.name}")
        return

    # Imported only once there is something to draw, so an empty call never needs matplotlib.
    import matplotlib.pyplot as plt

    num_charts = len(charts_to_generate)
    fig, axes = plt.subplots(1, num_charts, figsize=(7 * num_charts, 6), squeeze=False)
    try:
        fig.suptitle(suptitle, fontsize=16)

        # squeeze=False keeps `axes` two-dimensional, so row 0 holds one Axes per chart.
        for ax, chart_info in zip(axes[0], charts_to_generate):
            data = chart_info["data"]
            x_values = [point[0] for point in data]
            y_values = [point[1] for point in data]

            ax.plot(x_values, y_values, marker="o", linestyle="-")
            ax.set_title(chart_info["title"])
            ax.set_xlabel(chart_info.get("xlabel", "QPS (requested rate)"))
            ax.set_ylabel(chart_info["ylabel"])
            ax.grid(True)

        # Leave room at the top of the canvas for the suptitle.
        fig.tight_layout(rect=(0, 0.03, 1, 0.95))
        # Save this figure explicitly rather than whatever pyplot considers "current".
        fig.savefig(output_path)
        logger.info(f"Chart saved to {output_path}")
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def analyze_reports(report_dir: str) -> None:
    """
    Analyzes performance reports to generate charts.

    Reads every ``stage_*_lifecycle_metrics.json`` file in ``report_dir`` and writes up to three
    figures into the same directory: ``latency_vs_qps.png``, ``throughput_vs_qps.png`` and
    ``throughput_vs_latency.png``. Stage files that cannot be parsed, or that lack a requested
    rate or success data, are logged and skipped. If matplotlib is not installed an error is
    logged and nothing is generated.

    Args:
        report_dir: The directory containing the report files.
    """
    try:
        # Check for matplotlib and provide a helpful error message if it's not installed.
        import matplotlib  # noqa: F401
    except ImportError:
        logger.error(
            "matplotlib is not installed. Please install it to use the --analyze feature.\n"
            "You can install it via 'pip install .[analysis]'"
        )
        return

    logger.info(f"Analyzing reports in {report_dir}")

    # Find stage lifecycle metrics files
    report_path = Path(report_dir)
    stage_files = list(report_path.glob("stage_*_lifecycle_metrics.json"))

    if not stage_files:
        logger.error(f"No stage lifecycle metrics files found in {report_dir}")
        return

    # Keys read from the "latency" and "throughput" sections of each report.
    ttft_key = "time_to_first_token"
    ntpot_key = "normalized_time_per_output_token"
    itl_key = "inter_token_latency"
    itps_key = "input_tokens_per_sec"
    otps_key = "output_tokens_per_sec"
    ttps_key = "total_tokens_per_sec"
    latency_keys = (ttft_key, ntpot_key, itl_key)
    throughput_keys = (itps_key, otps_key, ttps_key)

    # One list of (x, y) points per metric, for each of the three figures.
    qps_vs_latency: Dict[str, List[Tuple[float, float]]] = {key: [] for key in latency_keys}
    qps_vs_throughput: Dict[str, List[Tuple[float, float]]] = {key: [] for key in throughput_keys}
    latency_vs_otps: Dict[str, List[Tuple[float, float]]] = {key: [] for key in latency_keys}

    for stage_file in stage_files:
        try:
            # JSON is read as UTF-8 explicitly so parsing does not depend on the platform locale.
            with open(stage_file, "r", encoding="utf-8") as f:
                report_data = json.load(f)

            # Get QPS from report file
            qps = report_data.get("load_summary", {}).get("requested_rate")
            if qps is None:
                logger.warning(f"Could not find requested_rate in {stage_file.name}. Skipping.")
                continue

            success_data = report_data.get("successes", {})
            if not success_data:
                logger.warning(f"No success data in {stage_file.name}. Skipping.")
                continue

            # Extract latency metrics (converted to milliseconds) if they exist
            latency_ms: Dict[str, float | None] = dict.fromkeys(latency_keys)
            latency_data = success_data.get("latency", {})
            if latency_data:
                for key in latency_keys:
                    value = _extract_latency_metric(latency_data, key, convert_to_ms=True)
                    latency_ms[key] = value
                    if value is not None:
                        qps_vs_latency[key].append((qps, value))

            # Extract throughput metrics if they exist
            otps = None
            throughput_data = success_data.get("throughput", {})
            if throughput_data:
                for key in throughput_keys:
                    value = _extract_throughput_metric(throughput_data, key)
                    if value is not None:
                        qps_vs_throughput[key].append((qps, value))
                    if key == otps_key:
                        otps = value

            # Populate latency vs throughput data (needs output throughput as the y value)
            if otps is not None:
                for key in latency_keys:
                    value = latency_ms[key]
                    if value is not None:
                        latency_vs_otps[key].append((value, otps))

        except json.JSONDecodeError:
            logger.error(f"Error decoding JSON from {stage_file.name}")
            continue
        except Exception as e:
            # Deliberately broad: one malformed stage file must not abort the whole analysis.
            logger.error(f"An unexpected error occurred while processing {stage_file.name}: {e}")
            continue

    # --- Generate Latency Plot ---
    _generate_plot(
        _build_charts(
            [
                (ttft_key, "Time to First Token vs. QPS", "Mean TTFT (ms)", None),
                (ntpot_key, "Norm. Time per Output Token vs. QPS", "Mean Norm. Time (ms/token)", None),
                (itl_key, "Inter-Token Latency vs. QPS", "Mean ITL (ms)", None),
            ],
            qps_vs_latency,
        ),
        "Latency vs Request Rate",
        report_path / "latency_vs_qps.png",
    )

    # --- Generate Throughput Plot ---
    _generate_plot(
        _build_charts(
            [
                (itps_key, "Input Tokens/sec vs. QPS", "Tokens/sec", None),
                (otps_key, "Output Tokens/sec vs. QPS", "Tokens/sec", None),
                (ttps_key, "Total Tokens/sec vs. QPS", "Tokens/sec", None),
            ],
            qps_vs_throughput,
        ),
        "Throughput vs Request Rate",
        report_path / "throughput_vs_qps.png",
    )

    # --- Generate Throughput vs Latency Curve Plot ---
    _generate_plot(
        _build_charts(
            [
                (ntpot_key, "Throughput vs. Norm. Time per Output Token", "Output Tokens/sec", "Mean Norm. Time (ms/token)"),
                (ttft_key, "Throughput vs. Time to First Token", "Output Tokens/sec", "Mean TTFT (ms)"),
                (itl_key, "Throughput vs. Inter-Token Latency", "Output Tokens/sec", "Mean ITL (ms)"),
            ],
            latency_vs_otps,
        ),
        "Latency vs Throughput",
        report_path / "throughput_vs_latency.png",
    )


def _build_charts(
    specs: List[Tuple[str, str, str, str | None]],
    series: Dict[str, List[Tuple[float, float]]],
) -> List[Dict[str, Any]]:
    """Turn (series key, title, ylabel, xlabel) specs into chart dicts, skipping empty series.

    Points are sorted by their x value. "xlabel" is only set when a spec provides one, so the
    plotting code can fall back to its own default label.
    """
    charts: List[Dict[str, Any]] = []
    for key, title, ylabel, xlabel in specs:
        points = series[key]
        if not points:
            continue
        chart: Dict[str, Any] = {"title": title, "ylabel": ylabel}
        if xlabel is not None:
            chart["xlabel"] = xlabel
        chart["data"] = sorted(points, key=operator.itemgetter(0))
        charts.append(chart)
    return charts
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# Copyright 2025 The Kubernetes Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
from .base import InferenceAPIData, InferenceInfo, RequestLifecycleMetric, ErrorResponseInfo
|
|
15
|
+
from .chat import ChatCompletionAPIData, ChatMessage
|
|
16
|
+
from .completion import CompletionAPIData
|
|
17
|
+
|
|
18
|
+
__all__ = [
|
|
19
|
+
"InferenceAPIData",
|
|
20
|
+
"InferenceInfo",
|
|
21
|
+
"RequestLifecycleMetric",
|
|
22
|
+
"ErrorResponseInfo",
|
|
23
|
+
"ChatCompletionAPIData",
|
|
24
|
+
"ChatMessage",
|
|
25
|
+
"CompletionAPIData",
|
|
26
|
+
]
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
# Copyright 2025 The Kubernetes Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from abc import abstractmethod
|
|
16
|
+
from typing import Any, List, Optional
|
|
17
|
+
from aiohttp import ClientResponse
|
|
18
|
+
from pydantic import BaseModel
|
|
19
|
+
from inference_perf.utils.custom_tokenizer import CustomTokenizer
|
|
20
|
+
from inference_perf.config import APIConfig, APIType
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class InferenceInfo(BaseModel):
    """Token accounting extracted from a single inference response."""

    # Number of tokens counted for the request input.
    input_tokens: int = 0
    # Number of tokens counted for the generated output.
    output_tokens: int = 0
    # Timestamps collected while a response is read; stays empty unless the producer fills it.
    output_token_times: List[float] = []
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ErrorResponseInfo(BaseModel):
    """Describes a failed request by an error type label and a message."""

    # NOTE(review): the set of error_type values is not visible here; confirm against the producers.
    error_type: str
    error_msg: str
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class RequestLifecycleMetric(BaseModel):
    """Per-request record: timing, request/response payloads, token info and optional error."""

    # Load stage the request belongs to, when known.
    stage_id: Optional[int] = None
    # NOTE(review): the clock and units of the three timestamps are set by the producer; confirm there.
    scheduled_time: float
    start_time: float
    end_time: float
    # Request payload as a string.
    request_data: str
    # Response payload as a string, if one was captured.
    response_data: Optional[str] = None
    # Token accounting for the request.
    info: InferenceInfo
    # Error details, or None; no default is declared for this field.
    error: Optional[ErrorResponseInfo]
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class InferenceAPIData(BaseModel):
    """Interface for the data of one inference request against a specific API type.

    Every method is declared with @abstractmethod and raises NotImplementedError here;
    concrete API data classes provide the implementations.
    """

    @abstractmethod
    def get_api_type(self) -> APIType:
        """Return the APIType this request data belongs to."""
        raise NotImplementedError

    @abstractmethod
    def get_route(self) -> str:
        """Return the route the request is sent to."""
        raise NotImplementedError

    @abstractmethod
    def to_payload(self, model_name: str, max_tokens: int, ignore_eos: bool, streaming: bool) -> dict[str, Any]:
        """Build the request payload dict for the given model and generation options."""
        raise NotImplementedError

    @abstractmethod
    async def process_response(self, response: ClientResponse, config: APIConfig, tokenizer: CustomTokenizer) -> InferenceInfo:
        """Read ``response`` and return the resulting InferenceInfo."""
        raise NotImplementedError
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
# Copyright 2025 The Kubernetes Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Any, List
|
|
16
|
+
from aiohttp import ClientResponse
|
|
17
|
+
from pydantic import BaseModel
|
|
18
|
+
from inference_perf.apis import InferenceAPIData, InferenceInfo
|
|
19
|
+
from inference_perf.utils.custom_tokenizer import CustomTokenizer
|
|
20
|
+
from inference_perf.config import APIConfig, APIType
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ChatMessage(BaseModel):
    """A single chat message: who is speaking (role) and what is said (content)."""

    role: str
    content: str
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ChatCompletionAPIData(InferenceAPIData):
    """Request data for the chat completions API. Streaming is not supported."""

    messages: List[ChatMessage]
    # 0 means "not set": the first to_payload() call then stores the max_tokens it was given.
    max_tokens: int = 0

    def get_api_type(self) -> APIType:
        """Return APIType.Chat."""
        return APIType.Chat

    def get_route(self) -> str:
        """Return the chat completions route."""
        return "/v1/chat/completions"

    def to_payload(self, model_name: str, max_tokens: int, ignore_eos: bool, streaming: bool) -> dict[str, Any]:
        """Build the chat completion request payload.

        Args:
            model_name: Value sent as "model".
            max_tokens: Used (and stored on this object) only when ``self.max_tokens`` is 0.
            ignore_eos: Value sent as "ignore_eos".
            streaming: Must be false; streaming chat payloads are not supported.

        Raises:
            Exception: If ``streaming`` is true.
        """
        if streaming:
            raise Exception("Generating streaming request payloads for the Chat API is not currently supported.")
        if self.max_tokens == 0:
            self.max_tokens = max_tokens
        return {
            "model": model_name,
            "messages": [{"role": m.role, "content": m.content} for m in self.messages],
            "max_tokens": self.max_tokens,
            "ignore_eos": ignore_eos,
        }

    async def process_response(self, response: ClientResponse, config: APIConfig, tokenizer: CustomTokenizer) -> InferenceInfo:
        """Count prompt and output tokens for a non-streamed chat completion response.

        The prompt length is the token count of all message contents concatenated; the output
        length is the token count of all choices' message contents concatenated. When the
        response carries no choices, only ``input_tokens`` is populated.

        Raises:
            Exception: If ``config.streaming`` is true.
        """
        if config.streaming:
            raise Exception("Decoding streamed responses from the Chat API is not currently supported")

        data = await response.json()
        prompt_len = tokenizer.count_tokens("".join(m.content for m in self.messages))

        # A missing or null "choices" is handled the same as an empty list.
        choices = data.get("choices") or []
        if not choices:
            return InferenceInfo(input_tokens=prompt_len)

        # "message" and "content" can be present but null; treat both as empty text
        # instead of failing inside str.join.
        output_text = "".join((choice.get("message") or {}).get("content") or "" for choice in choices)
        output_len = tokenizer.count_tokens(output_text)
        return InferenceInfo(
            input_tokens=prompt_len,
            output_tokens=output_len,
        )
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
# Copyright 2025 The Kubernetes Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
import time
|
|
18
|
+
from typing import Any, List
|
|
19
|
+
|
|
20
|
+
from aiohttp import ClientResponse
|
|
21
|
+
from inference_perf.apis import InferenceAPIData, InferenceInfo
|
|
22
|
+
from inference_perf.utils.custom_tokenizer import CustomTokenizer
|
|
23
|
+
from inference_perf.config import APIConfig, APIType
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class CompletionAPIData(InferenceAPIData):
    """Request data for the text completions API, streamed or non-streamed."""

    prompt: str
    # 0 means "not set": the first to_payload() call then stores the max_tokens it was given.
    max_tokens: int = 0

    def get_api_type(self) -> APIType:
        """Return APIType.Completion."""
        return APIType.Completion

    def get_route(self) -> str:
        """Return the completions route."""
        return "/v1/completions"

    def to_payload(self, model_name: str, max_tokens: int, ignore_eos: bool, streaming: bool) -> dict[str, Any]:
        """Build the completion request payload.

        Args:
            model_name: Value sent as "model".
            max_tokens: Used (and stored on this object) only when ``self.max_tokens`` is 0.
            ignore_eos: Value sent as "ignore_eos".
            streaming: Value sent as "stream".
        """
        if self.max_tokens == 0:
            self.max_tokens = max_tokens
        return {
            "model": model_name,
            "prompt": self.prompt,
            "max_tokens": self.max_tokens,
            "ignore_eos": ignore_eos,
            "stream": streaming,
        }

    async def process_response(self, response: ClientResponse, config: APIConfig, tokenizer: CustomTokenizer) -> InferenceInfo:
        """Count prompt and output tokens for a completion response.

        When ``config.streaming`` is true the body is read line by line and one
        ``time.perf_counter()`` reading is recorded for each chunk that carries choices; blank
        lines and the "[DONE]" marker record nothing. Otherwise the body is parsed as one JSON
        document and only the first choice's text is counted.
        """
        if config.streaming:
            text_parts: List[str] = []
            output_token_times: List[float] = []
            async for chunk_bytes in response.content:
                # Take the reading first so JSON decoding time is not charged to the token.
                received_at = time.perf_counter()
                chunk_bytes = chunk_bytes.strip()
                if not chunk_bytes:
                    # Blank lines between chunks carry no token and must not add a timestamp.
                    continue
                # After removing the "data: " prefix, each chunk decodes to a response json for a single token or "[DONE]" if end of stream
                chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
                if chunk == "[DONE]":
                    continue
                data = json.loads(chunk)
                if choices := data.get("choices"):
                    output_token_times.append(received_at)
                    # "text" may be absent or null; treat it as empty instead of failing.
                    text_parts.append(choices[0].get("text") or "")
            output_text = "".join(text_parts)
            prompt_len = tokenizer.count_tokens(self.prompt)
            output_len = tokenizer.count_tokens(output_text)
            return InferenceInfo(
                input_tokens=prompt_len,
                output_tokens=output_len,
                output_token_times=output_token_times,
            )

        data = await response.json()
        prompt_len = tokenizer.count_tokens(self.prompt)
        # A missing or null "choices" is handled the same as an empty list.
        choices = data.get("choices") or []
        if not choices:
            return InferenceInfo(input_tokens=prompt_len)
        output_text = choices[0].get("text") or ""
        output_len = tokenizer.count_tokens(output_text)
        return InferenceInfo(input_tokens=prompt_len, output_tokens=output_len)
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# Copyright 2025 The Kubernetes Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
from .base import StorageClient
|
|
15
|
+
from .local import LocalStorageClient
|
|
16
|
+
from .gcs import GoogleCloudStorageClient
|
|
17
|
+
from .s3 import SimpleStorageServiceClient
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
__all__ = ["StorageClient", "LocalStorageClient", "GoogleCloudStorageClient","SimpleStorageServiceClient"]
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
# Copyright 2025 The Kubernetes Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import logging
|
|
16
|
+
from abc import ABC, abstractmethod
|
|
17
|
+
from typing import List
|
|
18
|
+
from inference_perf.config import StorageConfigBase
|
|
19
|
+
from inference_perf.utils import ReportFile
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class StorageClient(ABC):
    """Base class for clients that persist report files to a storage backend."""

    def __init__(self, config: StorageConfigBase) -> None:
        """Keep the storage configuration and log the destination path."""
        self.config = config
        destination = self.config.path
        logger.info(f"Report files will be stored at: {destination}")

    @abstractmethod
    def save_report(self, reports: List[ReportFile]) -> None:
        """Persist the given report files; concrete clients must override this."""
        raise NotImplementedError()
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
# Copyright 2025 The Kubernetes Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
import json
|
|
15
|
+
import logging
|
|
16
|
+
from typing import List
|
|
17
|
+
from google.cloud import storage
|
|
18
|
+
from google.cloud.exceptions import GoogleCloudError
|
|
19
|
+
from inference_perf.client.filestorage import StorageClient
|
|
20
|
+
from inference_perf.config import GoogleCloudStorageConfig
|
|
21
|
+
from inference_perf.utils import ReportFile
|
|
22
|
+
|
|
23
|
+
logger = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class GoogleCloudStorageClient(StorageClient):
    """Storage client that uploads report files as JSON objects to a Google Cloud Storage bucket."""

    def __init__(self, config: GoogleCloudStorageConfig) -> None:
        """Create the GCS client and resolve the target bucket.

        Raises:
            ValueError: If looking up ``config.bucket_name`` yields no bucket.
        """
        super().__init__(config=config)
        logger.debug("Created new GCS client")
        self.output_bucket = config.bucket_name
        self.client = storage.Client()

        self.bucket = self.client.lookup_bucket(config.bucket_name)
        if self.bucket is None:
            raise ValueError(f"GCS bucket '{config.bucket_name}' does not exist or is inaccessible.")

    def _blob_path(self, filename: str) -> str:
        """Return the object name for ``filename``: ``<path>/<report_file_prefix><filename>``.

        The "/" separator is only added when a path is configured, and trailing slashes on
        the path are dropped, so object names never start with "/" because of an empty path
        and never contain "//" at the join.
        """
        directory = (self.config.path or "").rstrip("/")
        object_name = f"{self.config.report_file_prefix or ''}{filename}"
        return f"{directory}/{object_name}" if directory else object_name

    def save_report(self, reports: List[ReportFile]) -> None:
        """Upload each report as a JSON object, skipping objects that already exist.

        Upload failures raised as GoogleCloudError are logged and do not stop the remaining
        uploads.

        Raises:
            ValueError: If two reports share the same filename.
        """
        filenames = [report.get_filename() for report in reports]
        if len(filenames) != len(set(filenames)):
            raise ValueError("Duplicate filenames detected", filenames)

        for report in reports:
            blob_path = self._blob_path(report.get_filename())
            blob = self.bucket.blob(blob_path)

            # Existing objects are never overwritten.
            if blob.exists():
                logger.info(f"Skipping upload: gs://{self.output_bucket}/{blob_path} already exists")
                continue

            try:
                blob.upload_from_string(json.dumps(report.get_contents()), content_type="application/json")
                logger.info(f"Uploaded gs://{self.output_bucket}/{blob_path}")
            except GoogleCloudError as e:
                logger.error(f"Failed to upload {blob_path}: {e}")
|