inference-perf 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51) hide show
  1. inference_perf/__init__.py +16 -0
  2. inference_perf/analysis/__init__.py +17 -0
  3. inference_perf/analysis/analyze.py +276 -0
  4. inference_perf/apis/__init__.py +26 -0
  5. inference_perf/apis/base.py +60 -0
  6. inference_perf/apis/chat.py +64 -0
  7. inference_perf/apis/completion.py +78 -0
  8. inference_perf/client/filestorage/__init__.py +20 -0
  9. inference_perf/client/filestorage/base.py +31 -0
  10. inference_perf/client/filestorage/gcs.py +55 -0
  11. inference_perf/client/filestorage/local.py +38 -0
  12. inference_perf/client/filestorage/s3.py +59 -0
  13. inference_perf/client/metricsclient/__init__.py +17 -0
  14. inference_perf/client/metricsclient/base.py +72 -0
  15. inference_perf/client/metricsclient/mock_client.py +28 -0
  16. inference_perf/client/metricsclient/prometheus_client/__init__.py +18 -0
  17. inference_perf/client/metricsclient/prometheus_client/base.py +284 -0
  18. inference_perf/client/metricsclient/prometheus_client/google_managed_prometheus_client.py +43 -0
  19. inference_perf/client/modelserver/__init__.py +19 -0
  20. inference_perf/client/modelserver/base.py +78 -0
  21. inference_perf/client/modelserver/mock_client.py +52 -0
  22. inference_perf/client/modelserver/vllm_client.py +245 -0
  23. inference_perf/client/requestdatacollector/__init__.py +23 -0
  24. inference_perf/client/requestdatacollector/base.py +31 -0
  25. inference_perf/client/requestdatacollector/local.py +30 -0
  26. inference_perf/client/requestdatacollector/multiprocess.py +56 -0
  27. inference_perf/config.py +193 -0
  28. inference_perf/datagen/__init__.py +28 -0
  29. inference_perf/datagen/base.py +66 -0
  30. inference_perf/datagen/hf_sharegpt_datagen.py +94 -0
  31. inference_perf/datagen/mock_datagen.py +43 -0
  32. inference_perf/datagen/random_datagen.py +105 -0
  33. inference_perf/datagen/shared_prefix_datagen.py +99 -0
  34. inference_perf/datagen/synthetic_datagen.py +591 -0
  35. inference_perf/loadgen/__init__.py +16 -0
  36. inference_perf/loadgen/load_generator.py +212 -0
  37. inference_perf/loadgen/load_timer.py +78 -0
  38. inference_perf/logger.py +35 -0
  39. inference_perf/main.py +244 -0
  40. inference_perf/reportgen/__init__.py +16 -0
  41. inference_perf/reportgen/base.py +272 -0
  42. inference_perf/utils/__init__.py +17 -0
  43. inference_perf/utils/custom_tokenizer.py +30 -0
  44. inference_perf/utils/distribution.py +60 -0
  45. inference_perf/utils/report_file.py +30 -0
  46. inference_perf-0.1.1.dist-info/METADATA +232 -0
  47. inference_perf-0.1.1.dist-info/RECORD +51 -0
  48. inference_perf-0.1.1.dist-info/WHEEL +5 -0
  49. inference_perf-0.1.1.dist-info/entry_points.txt +2 -0
  50. inference_perf-0.1.1.dist-info/licenses/LICENSE +201 -0
  51. inference_perf-0.1.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,16 @@
1
+ # Copyright 2025 The Kubernetes Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from .main import main_cli
15
+
16
# Names exported by ``from inference_perf import *``: only ``main_cli``,
# which is imported from ``.main`` above.
__all__ = ["main_cli"]
@@ -0,0 +1,17 @@
1
+ # Copyright 2025 The Kubernetes Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .analyze import analyze_reports
16
+
17
# Public API of this subpackage: the report-analysis entry point imported from ``.analyze``.
__all__ = ["analyze_reports"]
@@ -0,0 +1,276 @@
1
+ # Copyright 2025 The Kubernetes Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import logging
17
+ import operator
18
+ from pathlib import Path
19
+ from typing import Any, Dict, List, Tuple
20
+
21
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
22
+
23
+
24
+ def _extract_latency_metric(latency_data: Dict[str, Any], metric_name: str, convert_to_ms: bool = False) -> float | None:
25
+ """Helper to extract a metric's mean value from latency data."""
26
+ metric_data = latency_data.get(metric_name)
27
+ if isinstance(metric_data, dict):
28
+ mean_val = metric_data.get("mean")
29
+ if isinstance(mean_val, (int, float)):
30
+ return mean_val * 1000 if convert_to_ms else mean_val
31
+ return None
32
+
33
+
34
+ def _extract_throughput_metric(throughput_data: Dict[str, Any], metric_name: str) -> float | None:
35
+ """Helper to extract a throughput metric's value."""
36
+ metric_value = throughput_data.get(metric_name)
37
+ if isinstance(metric_value, (int, float)):
38
+ return float(metric_value)
39
+ return None
40
+
41
+
42
def _generate_plot(charts_to_generate: List[Dict[str, Any]], suptitle: str, output_path: Path) -> None:
    """Render one row of line charts into a single figure and save it.

    Args:
        charts_to_generate: One dict per subplot with keys ``"title"``,
            ``"ylabel"``, ``"data"`` (a sequence of ``(x, y)`` pairs) and an
            optional ``"xlabel"`` (defaults to ``"QPS (requested rate)"``).
        suptitle: Title drawn above all subplots.
        output_path: Destination file, passed to the figure's ``savefig``.

    If ``charts_to_generate`` is empty, a warning is logged and nothing is written.
    """
    if not charts_to_generate:
        logger.warning(f"No data available to generate chart: {output_path.name}")
        return

    # Imported lazily: matplotlib is an optional dependency (see analyze_reports),
    # and there is no need to load it when there is nothing to draw.
    import matplotlib.pyplot as plt

    num_charts = len(charts_to_generate)
    fig, axes = plt.subplots(1, num_charts, figsize=(7 * num_charts, 6), squeeze=False)
    try:
        fig.suptitle(suptitle, fontsize=16)

        for ax, chart_info in zip(axes[0], charts_to_generate):
            data = chart_info["data"]
            x_values = [point[0] for point in data]
            y_values = [point[1] for point in data]

            ax.plot(x_values, y_values, marker="o", linestyle="-")
            ax.set_title(chart_info["title"])
            ax.set_xlabel(chart_info.get("xlabel", "QPS (requested rate)"))
            ax.set_ylabel(chart_info["ylabel"])
            ax.grid(True)

        # Leave room at the top for the suptitle.
        fig.tight_layout(rect=(0, 0.03, 1, 0.95))
        # Save this figure explicitly instead of pyplot's implicit "current" figure.
        fig.savefig(output_path)
        logger.info(f"Chart saved to {output_path}")
    finally:
        # Release the figure even if plotting or saving raised.
        plt.close(fig)
70
+
71
+
72
def _build_charts(specs: List[Tuple[Dict[str, str], List[Tuple[float, float]]]]) -> List[Dict[str, Any]]:
    """Turn ``(labels, points)`` pairs into chart dicts for ``_generate_plot``.

    Specs with no points are dropped. The remaining points are sorted by their
    x value so each line is drawn left to right.
    """
    return [{**labels, "data": sorted(points, key=operator.itemgetter(0))} for labels, points in specs if points]


def analyze_reports(report_dir: str) -> None:
    """
    Analyzes performance reports to generate charts.

    Every ``stage_*_lifecycle_metrics.json`` file in ``report_dir`` contributes
    points keyed by its ``load_summary.requested_rate``. Mean latencies (scaled
    to milliseconds) and throughputs are read from the ``successes`` section.
    Three images are written into ``report_dir``: ``latency_vs_qps.png``,
    ``throughput_vs_qps.png`` and ``throughput_vs_latency.png``.

    A stage file is logged and skipped if it cannot be parsed, lacks a numeric
    ``requested_rate``, or has no success data, so one bad file does not abort
    the analysis. If matplotlib is not installed, an error is logged and nothing
    is generated.

    Args:
        report_dir: The directory containing the report files.
    """
    try:
        # Check for matplotlib and provide a helpful error message if it's not installed.
        import matplotlib  # noqa: F401
    except ImportError:
        logger.error(
            "matplotlib is not installed. Please install it to use the --analyze feature.\n"
            "You can install it via 'pip install .[analysis]'"
        )
        return

    logger.info(f"Analyzing reports in {report_dir}")

    report_path = Path(report_dir)
    # Sorted so stage files are processed (and logged) in a deterministic order.
    stage_files = sorted(report_path.glob("stage_*_lifecycle_metrics.json"))
    if not stage_files:
        logger.error(f"No stage lifecycle metrics files found in {report_dir}")
        return

    # Latency (ms) vs requested QPS
    qps_vs_ttft: List[Tuple[float, float]] = []
    qps_vs_ntpot: List[Tuple[float, float]] = []
    qps_vs_itl: List[Tuple[float, float]] = []
    # Throughput (tokens/sec) vs requested QPS
    qps_vs_itps: List[Tuple[float, float]] = []
    qps_vs_otps: List[Tuple[float, float]] = []
    qps_vs_ttps: List[Tuple[float, float]] = []
    # (latency, output tokens/sec) pairs from the same stage
    ttft_vs_otps: List[Tuple[float, float]] = []
    ntpot_vs_otps: List[Tuple[float, float]] = []
    itl_vs_otps: List[Tuple[float, float]] = []

    for stage_file in stage_files:
        try:
            with open(stage_file, "r", encoding="utf-8") as f:
                report_data = json.load(f)

            # "or {}" also covers an explicit JSON null for load_summary.
            qps = (report_data.get("load_summary") or {}).get("requested_rate")
            if qps is None:
                logger.warning(f"Could not find requested_rate in {stage_file.name}. Skipping.")
                continue
            # All points are later sorted by QPS outside this try block. A
            # non-numeric rate would make that sort raise and abort the whole run.
            if isinstance(qps, bool) or not isinstance(qps, (int, float)):
                logger.warning(f"Invalid requested_rate {qps!r} in {stage_file.name}. Skipping.")
                continue

            success_data = report_data.get("successes", {})
            if not success_data:
                logger.warning(f"No success data in {stage_file.name}. Skipping.")
                continue

            # Extract latency metrics if they exist
            ttft, ntpot, itl = None, None, None
            latency_data = success_data.get("latency", {})
            if latency_data:
                ttft = _extract_latency_metric(latency_data, "time_to_first_token", convert_to_ms=True)
                if ttft is not None:
                    qps_vs_ttft.append((qps, ttft))

                ntpot = _extract_latency_metric(latency_data, "normalized_time_per_output_token", convert_to_ms=True)
                if ntpot is not None:
                    qps_vs_ntpot.append((qps, ntpot))

                itl = _extract_latency_metric(latency_data, "inter_token_latency", convert_to_ms=True)
                if itl is not None:
                    qps_vs_itl.append((qps, itl))

            # Extract throughput metrics if they exist
            otps = None
            throughput_data = success_data.get("throughput", {})
            if throughput_data:
                itps = _extract_throughput_metric(throughput_data, "input_tokens_per_sec")
                if itps is not None:
                    qps_vs_itps.append((qps, itps))

                otps = _extract_throughput_metric(throughput_data, "output_tokens_per_sec")
                if otps is not None:
                    qps_vs_otps.append((qps, otps))

                ttps = _extract_throughput_metric(throughput_data, "total_tokens_per_sec")
                if ttps is not None:
                    qps_vs_ttps.append((qps, ttps))

            # Pair each latency with the output throughput of the same stage.
            if otps is not None:
                if ttft is not None:
                    ttft_vs_otps.append((ttft, otps))
                if ntpot is not None:
                    ntpot_vs_otps.append((ntpot, otps))
                if itl is not None:
                    itl_vs_otps.append((itl, otps))

        except json.JSONDecodeError:
            logger.error(f"Error decoding JSON from {stage_file.name}")
        except Exception as e:
            # Best effort: keep the traceback in the log and move on to the next stage.
            logger.exception(f"An unexpected error occurred while processing {stage_file.name}: {e}")

    # --- Generate Latency Plot ---
    _generate_plot(
        _build_charts(
            [
                ({"title": "Time to First Token vs. QPS", "ylabel": "Mean TTFT (ms)"}, qps_vs_ttft),
                ({"title": "Norm. Time per Output Token vs. QPS", "ylabel": "Mean Norm. Time (ms/token)"}, qps_vs_ntpot),
                ({"title": "Inter-Token Latency vs. QPS", "ylabel": "Mean ITL (ms)"}, qps_vs_itl),
            ]
        ),
        "Latency vs Request Rate",
        report_path / "latency_vs_qps.png",
    )

    # --- Generate Throughput Plot ---
    _generate_plot(
        _build_charts(
            [
                ({"title": "Input Tokens/sec vs. QPS", "ylabel": "Tokens/sec"}, qps_vs_itps),
                ({"title": "Output Tokens/sec vs. QPS", "ylabel": "Tokens/sec"}, qps_vs_otps),
                ({"title": "Total Tokens/sec vs. QPS", "ylabel": "Tokens/sec"}, qps_vs_ttps),
            ]
        ),
        "Throughput vs Request Rate",
        report_path / "throughput_vs_qps.png",
    )

    # --- Generate Throughput vs Latency Curve Plot ---
    _generate_plot(
        _build_charts(
            [
                (
                    {
                        "title": "Throughput vs. Norm. Time per Output Token",
                        "xlabel": "Mean Norm. Time (ms/token)",
                        "ylabel": "Output Tokens/sec",
                    },
                    ntpot_vs_otps,
                ),
                (
                    {
                        "title": "Throughput vs. Time to First Token",
                        "xlabel": "Mean TTFT (ms)",
                        "ylabel": "Output Tokens/sec",
                    },
                    ttft_vs_otps,
                ),
                (
                    {
                        "title": "Throughput vs. Inter-Token Latency",
                        "xlabel": "Mean ITL (ms)",
                        "ylabel": "Output Tokens/sec",
                    },
                    itl_vs_otps,
                ),
            ]
        ),
        "Latency vs Throughput",
        report_path / "throughput_vs_latency.png",
    )
@@ -0,0 +1,26 @@
1
+ # Copyright 2025 The Kubernetes Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from .base import InferenceAPIData, InferenceInfo, RequestLifecycleMetric, ErrorResponseInfo
15
+ from .chat import ChatCompletionAPIData, ChatMessage
16
+ from .completion import CompletionAPIData
17
+
18
# Public API of the ``apis`` package: the base request/metric models and the
# concrete chat and completion request types imported above.
__all__ = [
    "InferenceAPIData",
    "InferenceInfo",
    "RequestLifecycleMetric",
    "ErrorResponseInfo",
    "ChatCompletionAPIData",
    "ChatMessage",
    "CompletionAPIData",
]
@@ -0,0 +1,60 @@
1
+ # Copyright 2025 The Kubernetes Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from abc import abstractmethod
16
+ from typing import Any, List, Optional
17
+ from aiohttp import ClientResponse
18
+ from pydantic import BaseModel
19
+ from inference_perf.utils.custom_tokenizer import CustomTokenizer
20
+ from inference_perf.config import APIConfig, APIType
21
+
22
+
23
class InferenceInfo(BaseModel):
    """Token accounting produced by parsing one inference response."""

    # Number of prompt tokens, as counted by the tokenizer passed to process_response.
    input_tokens: int = 0
    # Number of generated tokens, counted the same way.
    output_tokens: int = 0
    # time.perf_counter() timestamps taken while reading a streamed response;
    # left empty for non-streaming responses.
    output_token_times: List[float] = []
27
+
28
+
29
class ErrorResponseInfo(BaseModel):
    """Describes why a request failed."""

    # Category of the failure (presumably an exception/class name — TODO confirm with producers).
    error_type: str
    # Human-readable error message.
    error_msg: str
32
+
33
+
34
class RequestLifecycleMetric(BaseModel):
    """Timing, payload and outcome record for one request issued during a run."""

    # Load-generation stage the request belonged to, if known.
    stage_id: Optional[int] = None
    # Request timestamps; clock source and units are set by the producer (TODO confirm).
    scheduled_time: float
    start_time: float
    end_time: float
    # Request payload as a string, and the response body when one was captured.
    request_data: str
    response_data: Optional[str] = None
    # Token counts (and streamed token timings) parsed from the response.
    info: InferenceInfo
    # Failure details; presumably None for successful requests. Defaults to None,
    # like the other Optional fields. Under pydantic v2, an Optional annotation
    # without a default is still a *required* field, so callers had to pass it
    # explicitly.
    error: Optional[ErrorResponseInfo] = None
43
+
44
+
45
class InferenceAPIData(BaseModel):
    """Abstract base for the data of one request to a specific inference API.

    Subclasses report which API they target and the HTTP route to call. They
    also build the JSON payload and turn the HTTP response into an
    ``InferenceInfo``.
    """

    @abstractmethod
    def get_api_type(self) -> APIType:
        """Return the API flavor this request targets."""
        raise NotImplementedError

    @abstractmethod
    def get_route(self) -> str:
        """Return the URL path the request is sent to (e.g. ``/v1/completions``)."""
        raise NotImplementedError

    @abstractmethod
    def to_payload(self, model_name: str, max_tokens: int, ignore_eos: bool, streaming: bool) -> dict[str, Any]:
        """Build the JSON request body for ``model_name``.

        ``max_tokens``, ``ignore_eos`` and ``streaming`` are generation settings
        supplied by the caller. How they are applied is up to the subclass.
        """
        raise NotImplementedError

    @abstractmethod
    async def process_response(self, response: ClientResponse, config: APIConfig, tokenizer: CustomTokenizer) -> InferenceInfo:
        """Read ``response`` and return token counts (plus timings when collected)."""
        raise NotImplementedError
@@ -0,0 +1,64 @@
1
+ # Copyright 2025 The Kubernetes Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, List
16
+ from aiohttp import ClientResponse
17
+ from pydantic import BaseModel
18
+ from inference_perf.apis import InferenceAPIData, InferenceInfo
19
+ from inference_perf.utils.custom_tokenizer import CustomTokenizer
20
+ from inference_perf.config import APIConfig, APIType
21
+
22
+
23
class ChatMessage(BaseModel):
    """One message of a chat conversation, sent as ``{"role": ..., "content": ...}``."""

    # Speaker role string, forwarded verbatim to the server.
    role: str
    # Message text.
    content: str
26
+
27
+
28
class ChatCompletionAPIData(InferenceAPIData):
    """Request data for the ``/v1/chat/completions`` endpoint (non-streaming only)."""

    messages: List[ChatMessage]
    # Per-request output token limit; 0 means "use the limit passed to to_payload".
    max_tokens: int = 0

    def get_api_type(self) -> APIType:
        return APIType.Chat

    def get_route(self) -> str:
        return "/v1/chat/completions"

    def to_payload(self, model_name: str, max_tokens: int, ignore_eos: bool, streaming: bool) -> dict[str, Any]:
        """Build the chat-completions JSON body.

        If this request has no limit of its own (``self.max_tokens == 0``), the
        ``max_tokens`` argument is stored on the instance and used.

        Raises:
            NotImplementedError: if ``streaming`` is requested. This is a subclass
                of ``Exception``, so existing ``except Exception`` handlers still
                catch it.
        """
        if streaming:
            raise NotImplementedError("Generating streaming request payloads for the Chat API is not currently supported.")
        if self.max_tokens == 0:
            # Persisted on the instance, so later calls reuse this limit.
            self.max_tokens = max_tokens
        return {
            "model": model_name,
            "messages": [{"role": m.role, "content": m.content} for m in self.messages],
            "max_tokens": self.max_tokens,
            "ignore_eos": ignore_eos,
        }

    async def process_response(self, response: ClientResponse, config: APIConfig, tokenizer: CustomTokenizer) -> InferenceInfo:
        """Count prompt and completion tokens of a non-streaming chat response.

        The prompt length is the token count of all message contents
        concatenated; roles are not counted. The text of every returned choice
        is concatenated before counting. A missing or null ``message``/``content``
        counts as empty text.

        Raises:
            NotImplementedError: if ``config.streaming`` is set.
        """
        if config.streaming:
            raise NotImplementedError("Decoding streamed responses from the Chat API is not currently supported")

        data = await response.json()
        prompt_len = tokenizer.count_tokens("".join(m.content for m in self.messages))
        # "or []" also covers an explicit null "choices".
        choices = data.get("choices") or []
        if not choices:
            return InferenceInfo(input_tokens=prompt_len)
        # Some responses carry a null message content. Treat it as empty text
        # rather than letting str.join raise TypeError on None.
        output_text = "".join((choice.get("message") or {}).get("content") or "" for choice in choices)
        output_len = tokenizer.count_tokens(output_text)
        return InferenceInfo(
            input_tokens=prompt_len,
            output_tokens=output_len,
        )
@@ -0,0 +1,78 @@
1
+ # Copyright 2025 The Kubernetes Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import json
17
+ import time
18
+ from typing import Any, List
19
+
20
+ from aiohttp import ClientResponse
21
+ from inference_perf.apis import InferenceAPIData, InferenceInfo
22
+ from inference_perf.utils.custom_tokenizer import CustomTokenizer
23
+ from inference_perf.config import APIConfig, APIType
24
+
25
+
26
class CompletionAPIData(InferenceAPIData):
    """Request data for the ``/v1/completions`` endpoint (streaming and non-streaming)."""

    prompt: str
    # Per-request output token limit; 0 means "use the limit passed to to_payload".
    max_tokens: int = 0

    def get_api_type(self) -> APIType:
        return APIType.Completion

    def get_route(self) -> str:
        return "/v1/completions"

    def to_payload(self, model_name: str, max_tokens: int, ignore_eos: bool, streaming: bool) -> dict[str, Any]:
        """Build the completions JSON body.

        If this request has no limit of its own (``self.max_tokens == 0``), the
        ``max_tokens`` argument is stored on the instance and used.
        """
        if self.max_tokens == 0:
            # Persisted on the instance, so later calls reuse this limit.
            self.max_tokens = max_tokens
        return {
            "model": model_name,
            "prompt": self.prompt,
            "max_tokens": self.max_tokens,
            "ignore_eos": ignore_eos,
            "stream": streaming,
        }

    async def process_response(self, response: ClientResponse, config: APIConfig, tokenizer: CustomTokenizer) -> InferenceInfo:
        """Count prompt/output tokens and, when streaming, record per-chunk arrival times.

        Streaming responses are read line by line as server-sent events: ``data: <json>``
        lines carry completion chunks and ``data: [DONE]`` ends the stream. Blank
        separator lines and ``:`` comment lines are ignored.

        ``output_token_times`` gets one ``time.perf_counter()`` timestamp per data
        chunk that carries a choice. Before, it also got one for every blank SSE
        separator line and for ``[DONE]``, which added spurious near-zero gaps.
        """
        if config.streaming:
            text_parts: List[str] = []
            output_token_times: List[float] = []
            async for raw_line in response.content:
                # Timestamp on arrival, before any decoding or parsing work.
                arrival = time.perf_counter()
                line = raw_line.strip()
                if not line or line.startswith(b":"):
                    continue
                # After removing the "data:" prefix, each chunk is a JSON response
                # for a single token, or "[DONE]" at the end of the stream.
                chunk = line.decode("utf-8").removeprefix("data:").strip()
                if chunk == "[DONE]":
                    continue
                data = json.loads(chunk)
                if choices := data.get("choices"):
                    output_token_times.append(arrival)
                    # A null "text" is treated as empty rather than crashing the concatenation.
                    text_parts.append(choices[0].get("text") or "")
            output_text = "".join(text_parts)
            prompt_len = tokenizer.count_tokens(self.prompt)
            output_len = tokenizer.count_tokens(output_text)
            return InferenceInfo(
                input_tokens=prompt_len,
                output_tokens=output_len,
                output_token_times=output_token_times,
            )

        data = await response.json()
        prompt_len = tokenizer.count_tokens(self.prompt)
        # "or []" also covers an explicit null "choices".
        choices = data.get("choices") or []
        if not choices:
            return InferenceInfo(input_tokens=prompt_len)
        output_text = choices[0].get("text") or ""
        output_len = tokenizer.count_tokens(output_text)
        return InferenceInfo(input_tokens=prompt_len, output_tokens=output_len)
@@ -0,0 +1,20 @@
1
+ # Copyright 2025 The Kubernetes Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from .base import StorageClient
15
+ from .local import LocalStorageClient
16
+ from .gcs import GoogleCloudStorageClient
17
+ from .s3 import SimpleStorageServiceClient
18
+
19
+
20
# Public API of the ``filestorage`` package: the abstract StorageClient base
# plus the local, GCS and S3 backends imported above.
__all__ = ["StorageClient", "LocalStorageClient", "GoogleCloudStorageClient", "SimpleStorageServiceClient"]
@@ -0,0 +1,31 @@
1
+ # Copyright 2025 The Kubernetes Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import logging
16
+ from abc import ABC, abstractmethod
17
+ from typing import List
18
+ from inference_perf.config import StorageConfigBase
19
+ from inference_perf.utils import ReportFile
20
+
21
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
22
+
23
+
24
class StorageClient(ABC):
    """Abstract destination for generated report files.

    The shared constructor keeps the storage configuration and logs where
    reports will be written. Concrete backends implement ``save_report``.
    """

    def __init__(self, config: StorageConfigBase) -> None:
        self.config = config
        destination = self.config.path
        logger.info(f"Report files will be stored at: {destination}")

    @abstractmethod
    def save_report(self, reports: List[ReportFile]) -> None:
        """Persist the given report files to this storage backend."""
        raise NotImplementedError
@@ -0,0 +1,55 @@
1
+ # Copyright 2025 The Kubernetes Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import json
15
+ import logging
16
+ from typing import List
17
+ from google.cloud import storage
18
+ from google.cloud.exceptions import GoogleCloudError
19
+ from inference_perf.client.filestorage import StorageClient
20
+ from inference_perf.config import GoogleCloudStorageConfig
21
+ from inference_perf.utils import ReportFile
22
+
23
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
24
+
25
+
26
class GoogleCloudStorageClient(StorageClient):
    """Stores report files as JSON objects in a Google Cloud Storage bucket.

    Objects are named ``<path>/<report_file_prefix><filename>``. When ``path`` is
    empty, objects go to the bucket root. Existing objects are never overwritten.
    """

    def __init__(self, config: GoogleCloudStorageConfig) -> None:
        """Create the GCS client and resolve the configured bucket.

        Raises:
            ValueError: if the bucket does not exist or is inaccessible.
        """
        super().__init__(config=config)
        logger.debug("Created new GCS client")
        self.output_bucket = config.bucket_name
        self.client = storage.Client()

        # lookup_bucket signals a missing/inaccessible bucket by returning None.
        self.bucket = self.client.lookup_bucket(config.bucket_name)
        if self.bucket is None:
            raise ValueError(f"GCS bucket '{config.bucket_name}' does not exist or is inaccessible.")

    def _blob_path(self, filename: str) -> str:
        """Return the object name for ``filename`` within the bucket.

        Slashes around the configured ``path`` are trimmed. Without this, an empty
        path produced an object name with a leading ``/`` and a path ending in
        ``/`` produced ``//``.
        """
        prefix = self.config.report_file_prefix or ""
        directory = str(self.config.path or "").strip("/")
        object_name = f"{prefix}{filename}"
        return f"{directory}/{object_name}" if directory else object_name

    def save_report(self, reports: List[ReportFile]) -> None:
        """Upload each report as a JSON object, skipping objects that already exist.

        Upload failures are logged per report and do not stop the remaining uploads.

        Raises:
            ValueError: if two reports share the same filename.
        """
        filenames = [report.get_filename() for report in reports]
        if len(filenames) != len(set(filenames)):
            raise ValueError("Duplicate filenames detected", filenames)

        for report, filename in zip(reports, filenames):
            blob_path = self._blob_path(filename)
            blob = self.bucket.blob(blob_path)
            try:
                # exists() is itself an API call that can fail, so it is inside the handler too.
                if blob.exists():
                    logger.info(f"Skipping upload: gs://{self.output_bucket}/{blob_path} already exists")
                    continue
                blob.upload_from_string(json.dumps(report.get_contents()), content_type="application/json")
                logger.info(f"Uploaded gs://{self.output_bucket}/{blob_path}")
            except GoogleCloudError as e:
                # Best effort: one failed report should not block the others.
                logger.error(f"Failed to upload {blob_path}: {e}")