tpu-inference 0.11.1rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (123)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  53. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  54. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  55. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  56. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  57. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  58. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  59. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  60. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  61. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  63. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  65. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  67. tpu_inference/logger.py +10 -0
  68. tpu_inference/lora/__init__.py +0 -0
  69. tpu_inference/lora/torch_lora_ops.py +103 -0
  70. tpu_inference/lora/torch_punica_tpu.py +308 -0
  71. tpu_inference/mock/__init__.py +0 -0
  72. tpu_inference/mock/vllm_config_utils.py +28 -0
  73. tpu_inference/mock/vllm_envs.py +1233 -0
  74. tpu_inference/mock/vllm_logger.py +212 -0
  75. tpu_inference/mock/vllm_logging_utils.py +15 -0
  76. tpu_inference/models/__init__.py +0 -0
  77. tpu_inference/models/jax/__init__.py +0 -0
  78. tpu_inference/models/jax/deepseek_v3.py +868 -0
  79. tpu_inference/models/jax/llama3.py +366 -0
  80. tpu_inference/models/jax/llama4.py +473 -0
  81. tpu_inference/models/jax/llama_eagle3.py +333 -0
  82. tpu_inference/models/jax/phi3.py +376 -0
  83. tpu_inference/models/jax/qwen2.py +375 -0
  84. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  85. tpu_inference/models/jax/qwen3.py +302 -0
  86. tpu_inference/models/jax/utils/__init__.py +0 -0
  87. tpu_inference/models/jax/utils/file_utils.py +96 -0
  88. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  89. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  90. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  91. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  92. tpu_inference/models/vllm/__init__.py +0 -0
  93. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  94. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  95. tpu_inference/platforms/__init__.py +2 -0
  96. tpu_inference/platforms/tpu_jax.py +257 -0
  97. tpu_inference/runner/__init__.py +0 -0
  98. tpu_inference/runner/block_table_jax.py +122 -0
  99. tpu_inference/runner/compilation_manager.py +672 -0
  100. tpu_inference/runner/input_batch_jax.py +435 -0
  101. tpu_inference/runner/kv_cache.py +119 -0
  102. tpu_inference/runner/kv_cache_manager.py +460 -0
  103. tpu_inference/runner/lora_utils.py +92 -0
  104. tpu_inference/runner/multimodal_manager.py +208 -0
  105. tpu_inference/runner/persistent_batch_manager.py +244 -0
  106. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  107. tpu_inference/runner/structured_decoding_manager.py +89 -0
  108. tpu_inference/runner/tpu_jax_runner.py +771 -0
  109. tpu_inference/runner/utils.py +426 -0
  110. tpu_inference/spec_decode/__init__.py +0 -0
  111. tpu_inference/spec_decode/jax/__init__.py +0 -0
  112. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  113. tpu_inference/tpu_info.py +77 -0
  114. tpu_inference/utils.py +294 -0
  115. tpu_inference/worker/__init__.py +0 -0
  116. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  117. tpu_inference/worker/base.py +100 -0
  118. tpu_inference/worker/tpu_worker_jax.py +321 -0
  119. tpu_inference-0.11.1rc1.dist-info/METADATA +101 -0
  120. tpu_inference-0.11.1rc1.dist-info/RECORD +123 -0
  121. tpu_inference-0.11.1rc1.dist-info/WHEEL +5 -0
  122. tpu_inference-0.11.1rc1.dist-info/licenses/LICENSE +201 -0
  123. tpu_inference-0.11.1rc1.dist-info/top_level.txt +2 -0
@@ -0,0 +1,212 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ """Logging configuration for vLLM."""
4
+ import datetime
5
+ import json
6
+ import logging
7
+ import sys
8
+ from collections.abc import Hashable
9
+ from functools import lru_cache, partial
10
+ from logging import Logger
11
+ from logging.config import dictConfig
12
+ from os import path
13
+ from types import MethodType
14
+ from typing import Any, cast
15
+
16
+ import tpu_inference.mock.vllm_envs as envs
17
+
18
# Logging-related settings, read once from the env module at import time.
VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING
VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH
VLLM_LOGGING_LEVEL = envs.VLLM_LOGGING_LEVEL
VLLM_LOGGING_PREFIX = envs.VLLM_LOGGING_PREFIX

# Line layout: optional user prefix, level, timestamp, source location, then
# the message itself.
_FORMAT = (f"{VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s "
           "[%(filename)s:%(lineno)d] %(message)s")
_DATE_FORMAT = "%m-%d %H:%M:%S"

# dictConfig schema used when VLLM_CONFIGURE_LOGGING is enabled and no custom
# config file is given.
DEFAULT_LOGGING_CONFIG = {
    "formatters": {
        "vllm": {
            # The formatter lives in tpu_inference/mock/vllm_logging_utils.py;
            # there is no tpu_inference.vllm_logging_utils module, so the
            # default must name the real location directly.
            "class": "tpu_inference.mock.vllm_logging_utils.NewLineFormatter",
            "datefmt": _DATE_FORMAT,
            "format": _FORMAT,
        },
    },
    "handlers": {
        "vllm": {
            "class": "logging.StreamHandler",
            "formatter": "vllm",
            # The handler level is the effective filter; the logger below
            # passes everything from DEBUG up.
            "level": VLLM_LOGGING_LEVEL,
            "stream": "ext://sys.stdout",
        },
    },
    "loggers": {
        "vllm": {
            "handlers": ["vllm"],
            "level": "DEBUG",
            # Keep records from also reaching root-logger handlers.
            "propagate": False,
        },
    },
    "version": 1,
    "disable_existing_loggers": False,
}
53
+
54
+
55
@lru_cache
def _print_debug_once(logger: Logger, msg: str, *args: Hashable) -> None:
    """Emit ``msg % args`` at DEBUG level only on the first such call.

    Deduplication comes from the ``lru_cache`` wrapper: a repeated
    ``(logger, msg, *args)`` combination is answered from the cache and never
    reaches the logger. The bare ``@lru_cache`` keeps at most 128 entries, so
    a long-evicted combination can be logged again. All arguments must be
    hashable.
    """
    emit = logger.debug
    # stacklevel=2 skips this helper's frame so the record shows the
    # caller's file and line.
    emit(msg, *args, stacklevel=2)
59
+
60
+
61
@lru_cache
def _print_info_once(logger: Logger, msg: str, *args: Hashable) -> None:
    """INFO-level counterpart of the debug helper: log on first call only.

    A call whose ``(logger, msg, *args)`` is still cached by ``lru_cache``
    (default capacity 128) returns without logging. Arguments must be
    hashable.
    """
    emit = logger.info
    # Attribute the record to the caller (one frame up), not to this helper.
    emit(msg, *args, stacklevel=2)
65
+
66
+
67
@lru_cache
def _print_warning_once(logger: Logger, msg: str, *args: Hashable) -> None:
    """WARNING-level counterpart: log on first call only.

    Repeat calls with an identical, still-cached ``(logger, msg, *args)``
    key (``lru_cache`` default capacity 128) are no-ops. Arguments must be
    hashable.
    """
    emit = logger.warning
    # stacklevel=2 reports the caller's location instead of this helper's.
    emit(msg, *args, stacklevel=2)
71
+
72
+
73
class _VllmLogger(Logger):
    """
    Note:
        Exists only to give type checkers a view of the extra methods.
        ``init_logger`` binds the ``*_once`` helpers directly onto the
        ordinary [`logging.Logger`][] instance and merely ``cast``s it to
        this type. Patching the instance, rather than swapping the logger
        class, avoids conflicts with libraries that customize loggers
        themselves, such as `intel_extension_for_pytorch.utils._logger`.
    """

    def debug_once(self, msg: str, *args: Hashable) -> None:
        """
        Same as [`debug`][logging.Logger.debug], except that repeating an
        identical message/argument combination produces no further output.
        """
        _print_debug_once(self, msg, *args)

    def info_once(self, msg: str, *args: Hashable) -> None:
        """
        Same as [`info`][logging.Logger.info], except that repeating an
        identical message/argument combination produces no further output.
        """
        _print_info_once(self, msg, *args)

    def warning_once(self, msg: str, *args: Hashable) -> None:
        """
        Same as [`warning`][logging.Logger.warning], except that repeating an
        identical message/argument combination produces no further output.
        """
        _print_warning_once(self, msg, *args)
102
+
103
+
104
# Attribute name -> helper table that init_logger binds onto each logger.
# Built once at import so init_logger does not rebuild it per call.
_METHODS_TO_PATCH = dict(
    debug_once=_print_debug_once,
    info_once=_print_info_once,
    warning_once=_print_warning_once,
)
110
+
111
+
112
def _configure_vllm_root_logger() -> None:
    """Apply the logging configuration selected by the VLLM_* settings.

    * ``VLLM_CONFIGURE_LOGGING`` falsy (and no config path): nothing is
      configured.
    * ``VLLM_CONFIGURE_LOGGING`` truthy: a private copy of
      ``DEFAULT_LOGGING_CONFIG`` is used, unless ``VLLM_LOGGING_CONFIG_PATH``
      names a JSON file, whose top-level object then replaces it entirely.

    Formatter entries that still name the legacy
    ``tpu_inference.vllm_logging_utils`` path are redirected to
    ``tpu_inference.mock.vllm_logging_utils`` before ``dictConfig`` runs.

    Raises:
        RuntimeError: a config path is set while configuration is disabled,
            or the path does not exist.
        ValueError: the JSON document is not an object.
    """
    from copy import deepcopy

    legacy_formatter = "tpu_inference.vllm_logging_utils.NewLineFormatter"
    current_formatter = ("tpu_inference.mock.vllm_logging_utils."
                         "NewLineFormatter")

    if not VLLM_CONFIGURE_LOGGING and VLLM_LOGGING_CONFIG_PATH:
        raise RuntimeError(
            "VLLM_CONFIGURE_LOGGING evaluated to false, but "
            "VLLM_LOGGING_CONFIG_PATH was given. VLLM_LOGGING_CONFIG_PATH "
            "implies VLLM_CONFIGURE_LOGGING. Please enable "
            "VLLM_CONFIGURE_LOGGING or unset VLLM_LOGGING_CONFIG_PATH.")

    logging_config: dict[str, Any] = {}

    if VLLM_CONFIGURE_LOGGING:
        # Copy so the in-place compatibility rewrite below can never leak
        # into the module-level default.
        logging_config = deepcopy(DEFAULT_LOGGING_CONFIG)

    if VLLM_LOGGING_CONFIG_PATH:
        if not path.exists(VLLM_LOGGING_CONFIG_PATH):
            # Exceptions do not %-format their arguments, so build the
            # message explicitly.
            raise RuntimeError(
                "Could not load logging config. File does not exist: "
                f"{VLLM_LOGGING_CONFIG_PATH}")
        with open(VLLM_LOGGING_CONFIG_PATH, encoding="utf-8") as file:
            custom_config = json.load(file)

        if not isinstance(custom_config, dict):
            raise ValueError("Invalid logging config. Expected dict, got "
                             f"{type(custom_config).__name__}.")
        logging_config = custom_config

    for formatter in logging_config.get("formatters", {}).values():
        # This provides backwards compatibility after #10134.
        if formatter.get("class") == legacy_formatter:
            formatter["class"] = current_formatter

    if logging_config:
        dictConfig(logging_config)
148
+
149
+
150
# Logging is configured as a side effect of importing this module. Module
# code executes once per interpreter: the import system caches the module in
# sys.modules and serializes concurrent first imports with a per-module lock.
_configure_vllm_root_logger()
154
+
155
+
156
def init_logger(name: str) -> _VllmLogger:
    """Return the standard logger ``name`` with the ``*_once`` helpers bound.

    Obtaining loggers through this function guarantees this module has been
    imported, and therefore that the root vllm logging configuration has been
    applied, before the logger is used. Each entry of ``_METHODS_TO_PATCH``
    is bound to the ``logging.Logger`` instance as a method; the ``cast``
    only informs type checkers.
    """
    target = logging.getLogger(name)
    for attr_name, impl in _METHODS_TO_PATCH.items():
        bound = MethodType(impl, target)
        setattr(target, attr_name, bound)
    return cast(_VllmLogger, target)


# Logger for this module's own messages.
logger = init_logger(__name__)
170
+
171
+
172
def _trace_calls(log_path, root_dir, frame, event, arg=None):
    """``sys.settrace`` callback appending call/return records to a file.

    Only frames whose source path starts with ``root_dir`` are logged. Each
    record is one line: a timestamp, ``Call to``/``Return from``, the traced
    function with ``file:line``, then ``from``/``to`` the frame in
    ``frame.f_back`` with its ``file:line``. When there is no previous frame,
    its name and path are empty and its line is 0.

    Args:
        log_path: File that records are appended to (UTF-8).
        root_dir: Path prefix selecting which frames are logged.
        frame, event, arg: Supplied by the interpreter's tracing machinery.

    Returns:
        A ``partial`` of this function, which keeps tracing the frame, or
        ``None`` for call/return events outside ``root_dir``; returning
        ``None`` from a ``call`` event disables local tracing of that frame.
    """
    if event in ("call", "return"):
        code = frame.f_code
        filename = code.co_filename
        lineno = frame.f_lineno
        func_name = code.co_name
        if not filename.startswith(root_dir):
            # only log the functions in the vllm root_dir
            return
        try:
            last_frame = frame.f_back
            if last_frame is not None:
                last_filename = last_frame.f_code.co_filename
                last_lineno = last_frame.f_lineno
                last_func_name = last_frame.f_code.co_name
            else:
                # initial frame
                last_filename, last_lineno, last_func_name = "", 0, ""
            if event == "call":
                direction, link = "Call to", "from"
            else:
                direction, link = "Return from", "to"
            ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
            record = (f"{ts} {direction} {func_name} in {filename}:{lineno}"
                      f" {link} {last_func_name} in {last_filename}:"
                      f"{last_lineno}\n")
            # Explicit UTF-8: with the locale default, a path or name that
            # the locale cannot encode would raise inside the trace function
            # and abort tracing.
            with open(log_path, "a", encoding="utf-8") as f:
                f.write(record)
        except NameError:
            # modules are deleted during shutdown
            pass
    return partial(_trace_calls, log_path, root_dir)
209
+
210
+
211
def enable_trace_function_call(log_file_path: str, root_dir: str):
    """Install ``_trace_calls`` as the current thread's global trace function.

    Calls into and returns from functions whose source file path starts
    with ``root_dir`` are appended to ``log_file_path``.
    """
    tracer = partial(_trace_calls, log_file_path, root_dir)
    sys.settrace(tracer)
@@ -0,0 +1,15 @@
1
+ import logging
2
+
3
+
4
class NewLineFormatter(logging.Formatter):
    """Adds logging prefix to newlines to align multi-line messages.

    Every line after the first (message continuation lines and any appended
    exception or stack text) is prefixed with whatever the format string
    renders before the message, and line breaks are emitted as CRLF.
    """

    # Placeholder used to find where the message sits in a rendered line.
    # The NUL characters make a collision with real log text very unlikely.
    _MESSAGE_SENTINEL = "\x00NewLineFormatter-message\x00"

    def __init__(self, fmt, datefmt=None, style="%", *args, **kwargs):
        # Forward extra arguments (``validate``, and ``defaults`` on newer
        # Pythons) that logging.config.dictConfig may pass to formatters.
        super().__init__(fmt, datefmt, style, *args, **kwargs)

    def format(self, record):
        msg = super().format(record)
        # Only multi-line output needs re-prefixing; an empty message gives
        # no position to anchor the prefix to.
        if record.message != "" and "\n" in msg:
            msg = msg.replace("\n", "\r\n" + self._line_prefix(record, msg))
        return msg

    def _line_prefix(self, record, formatted):
        """Return the text the format string renders before the message."""
        import copy

        # Render a copy whose message is a sentinel. Searching for the real
        # message would match its first occurrence, which may be inside the
        # prefix itself (e.g. message "x" with filename "xyz.py").
        probe = copy.copy(record)
        probe.message = self._MESSAGE_SENTINEL
        rendered = self.formatMessage(probe)
        cut = rendered.find(self._MESSAGE_SENTINEL)
        if cut >= 0:
            return rendered[:cut]
        # The format transforms the message (e.g. precision/conversion
        # flags), hiding the sentinel: fall back to the text preceding the
        # first occurrence of the message.
        return formatted.partition(record.message)[0]
File without changes
File without changes