llama-stack 0.4.4__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (159)
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. llama_stack/{distributions/meta-reference-gpu → core/connectors}/__init__.py +3 -1
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  30. llama_stack/distributions/nvidia/config.yaml +4 -1
  31. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  32. llama_stack/distributions/oci/config.yaml +4 -1
  33. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  34. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  35. llama_stack/distributions/starter/build.yaml +62 -0
  36. llama_stack/distributions/starter/config.yaml +22 -3
  37. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  38. llama_stack/distributions/starter/starter.py +13 -1
  39. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  40. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  41. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  42. llama_stack/distributions/template.py +10 -2
  43. llama_stack/distributions/watsonx/config.yaml +4 -1
  44. llama_stack/log.py +1 -0
  45. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  46. llama_stack/providers/inline/agents/meta_reference/agents.py +58 -61
  47. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +53 -51
  48. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +99 -22
  49. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  50. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  51. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  52. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  53. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  54. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  55. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  56. llama_stack/providers/inline/post_training/torchtune/common/utils.py +5 -9
  57. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  58. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  59. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  60. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +20 -24
  61. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  62. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  63. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  64. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  65. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  66. llama_stack/providers/registry/agents.py +1 -0
  67. llama_stack/providers/registry/inference.py +1 -9
  68. llama_stack/providers/registry/vector_io.py +136 -16
  69. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  70. llama_stack/providers/remote/files/s3/config.py +5 -3
  71. llama_stack/providers/remote/files/s3/files.py +2 -2
  72. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  73. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  74. llama_stack/providers/remote/inference/together/together.py +4 -0
  75. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  76. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  77. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  78. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  79. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  80. llama_stack/providers/remote/post_training/nvidia/models.py +3 -11
  81. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  82. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  83. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  84. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  85. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  86. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  87. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  88. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  89. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  90. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  91. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  92. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  93. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  94. llama_stack/providers/utils/bedrock/client.py +3 -3
  95. llama_stack/providers/utils/bedrock/config.py +7 -7
  96. llama_stack/providers/utils/inference/__init__.py +0 -25
  97. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  98. llama_stack/providers/utils/inference/http_client.py +239 -0
  99. llama_stack/providers/utils/inference/litellm_openai_mixin.py +6 -0
  100. llama_stack/providers/utils/inference/model_registry.py +148 -2
  101. llama_stack/providers/utils/inference/openai_compat.py +1 -158
  102. llama_stack/providers/utils/inference/openai_mixin.py +42 -2
  103. llama_stack/providers/utils/inference/prompt_adapter.py +0 -209
  104. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  105. llama_stack/providers/utils/memory/vector_store.py +46 -19
  106. llama_stack/providers/utils/responses/responses_store.py +7 -7
  107. llama_stack/providers/utils/safety.py +114 -0
  108. llama_stack/providers/utils/tools/mcp.py +44 -3
  109. llama_stack/testing/api_recorder.py +9 -3
  110. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/METADATA +14 -2
  111. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/RECORD +115 -148
  112. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  113. llama_stack/distributions/meta-reference-gpu/doc_template.md +0 -119
  114. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  115. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  116. llama_stack/models/llama/hadamard_utils.py +0 -88
  117. llama_stack/models/llama/llama3/args.py +0 -74
  118. llama_stack/models/llama/llama3/dog.jpg +0 -0
  119. llama_stack/models/llama/llama3/generation.py +0 -378
  120. llama_stack/models/llama/llama3/model.py +0 -304
  121. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  122. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  123. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  124. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  125. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  126. llama_stack/models/llama/llama3/pasta.jpeg +0 -0
  127. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  128. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  129. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  130. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  131. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  132. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  133. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  134. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  135. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  136. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  137. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  138. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  139. llama_stack/models/llama/llama4/args.py +0 -107
  140. llama_stack/models/llama/llama4/ffn.py +0 -58
  141. llama_stack/models/llama/llama4/moe.py +0 -214
  142. llama_stack/models/llama/llama4/preprocess.py +0 -435
  143. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  144. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  145. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  146. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  147. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  148. llama_stack/models/llama/quantize_impls.py +0 -316
  149. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  150. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  151. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  152. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  153. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  154. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  155. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  156. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/WHEEL +0 -0
  157. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/entry_points.txt +0 -0
  158. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/licenses/LICENSE +0 -0
  159. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/top_level.txt +0 -0
llama_stack/models/llama/llama4/moe.py
@@ -1,214 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- # ruff: noqa: N806
- # pyre-strict
- from typing import Any
-
- import fairscale.nn.model_parallel.initialize as fs_init
- import torch
- from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
- from torch import Tensor, nn
- from torch.nn import functional as F
-
- from .args import MoEArgs
- from .ffn import FeedForward
-
-
- class Experts(nn.Module):
-     def __init__(
-         self,
-         num_local_experts: int,
-         dim: int,
-         hidden_dim: int,
-     ) -> None:
-         super().__init__()
-
-         dtype = torch.get_default_dtype()
-         self.num_local_experts = num_local_experts
-         self.dim = dim
-         divide_factor = fs_init.get_model_parallel_world_size()
-
-         self.w1: nn.Parameter = nn.Parameter(
-             torch.empty(
-                 num_local_experts,
-                 dim,
-                 divide_exact(hidden_dim, divide_factor),
-                 dtype=dtype,
-             )
-         )
-
-         self.w2: nn.Parameter = nn.Parameter(
-             torch.empty(
-                 num_local_experts,
-                 divide_exact(hidden_dim, divide_factor),
-                 dim,
-                 dtype=dtype,
-             )
-         )
-
-         self.w3: nn.Parameter = nn.Parameter(
-             torch.empty(
-                 num_local_experts,
-                 dim,
-                 divide_exact(hidden_dim, divide_factor),
-                 dtype=dtype,
-             )
-         )
-
-         self._register_load_state_dict_pre_hook(self.load_hook)
-
-     def load_hook(
-         self,
-         state_dict: dict[str, Any],
-         prefix: str,
-         local_metadata: dict[str, Any],
-         strict: bool,
-         missing_keys: list[str],
-         unexpected_keys: list[str],
-         error_msgs: list[str],
-     ) -> None:
-         self.prefix = prefix
-         if prefix + "moe_w_in_eD_F" in state_dict:
-             e = self.num_local_experts
-             D = self.dim
-             state_dict[prefix + "w1"] = state_dict.pop(prefix + "moe_w_in_eD_F").view(e, D, -1)
-             state_dict[prefix + "w2"] = state_dict.pop(prefix + "moe_w_out_eF_D").view(e, -1, D)
-             state_dict[prefix + "w3"] = state_dict.pop(prefix + "moe_w_swiglu_eD_F").view(e, D, -1)
-
-     def forward(
-         self,
-         routed_in_egD: torch.Tensor,  # noqa: N803
-     ) -> torch.Tensor:
-         e = self.num_local_experts
-         D = self.dim
-
-         x_egD = routed_in_egD.view(e, -1, D)
-
-         out_egD = self.batched_swiglu(x_egD, self.w1, self.w3, self.w2)
-         out_egD = out_egD.view(-1, D)
-
-         return out_egD
-
-     def batched_swiglu(self, x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
-         middle_out_egF = F.silu(torch.bmm(x, w1)) * torch.bmm(x, w3)
-         return torch.bmm(middle_out_egF, w2)
-
-
- class MoE(torch.nn.Module):
-     """
-     Tensors used in this module are annotated with the suffixes that indicate the shape of the tensor.
-     Several commonly used annotations include:
-     - a: bsz*slen
-     - E: number of experts
-     - e: number of local experts per ep (n_experts/ep)
-     - D: hidden dimension
-     - d: D/tp
-     - F: model dimension
-     - G: number of tokens per expert (a * capacity_factor / E)
-     - g: number of tokens per expert per TP rank (i.e., G/TP)
-
-     Examples:
-     x_aD [a, D]
-     routed_in_etG_D [et*G, D]
-     x_eGD: [e, G, D]
-     """
-
-     def __init__(
-         self,
-         dim: int,
-         hidden_dim: int,
-         ffn_dim_multiplier: float,
-         multiple_of: int,
-         moe_args: MoEArgs,
-     ) -> None:
-         super().__init__()
-
-         self.moe_args = moe_args
-
-         hidden_dim_denom: float = 1
-         if moe_args.auto_scale_F:
-             hidden_dim_denom = moe_args.capacity_factor + 1
-
-         hidden_dim = int(2 * hidden_dim / 3)
-
-         # custom dim factor multiplier
-         hidden_dim = int(ffn_dim_multiplier * hidden_dim)
-
-         if moe_args.auto_scale_F:
-             hidden_dim = int(hidden_dim / hidden_dim_denom)
-
-         hidden_dim += -hidden_dim % multiple_of
-
-         num_local_experts: int = moe_args.num_experts
-         dtype: torch.dtype = torch.get_default_dtype()
-         self.experts = Experts(
-             num_local_experts,
-             dim,
-             hidden_dim,
-         )
-
-         self.router_DE: nn.Parameter = nn.Parameter(torch.empty(dim, moe_args.num_experts, dtype=dtype))
-         self.shared_expert = FeedForward(dim, hidden_dim, do_reduce=False)
-
-         self._register_load_state_dict_pre_hook(self.load_hook)
-
-     def load_hook(
-         self,
-         state_dict: dict[str, Any],
-         prefix: str,
-         local_metadata: dict[str, Any],
-         strict: bool,
-         missing_keys: list[str],
-         unexpected_keys: list[str],
-         error_msgs: list[str],
-     ) -> None:
-         if prefix + "w_in_shared_FD.weight" in state_dict:
-             state_dict[prefix + "shared_expert.w1.weight"] = state_dict.pop(prefix + "w_in_shared_FD.weight")
-             state_dict[prefix + "shared_expert.w3.weight"] = state_dict.pop(prefix + "w_swiglu_FD.weight")
-             state_dict[prefix + "shared_expert.w2.weight"] = state_dict.pop(prefix + "w_out_shared_DF.weight")
-
-     def forward(self, x_bsD: Tensor) -> Tensor:  # noqa: N803
-         _, slen, D = x_bsD.shape
-         x_aD = x_bsD.view(-1, D)
-
-         a = x_aD.shape[0]
-
-         router_scores: Tensor = torch.matmul(x_aD, self.router_DE).transpose(0, 1)
-
-         router_scores_aK, router_indices_aK = torch.topk(router_scores.transpose(0, 1), self.moe_args.top_k, dim=1)
-         router_scores = (
-             torch.full_like(router_scores.transpose(0, 1), float("-inf"))
-             .scatter_(1, router_indices_aK, router_scores_aK)
-             .transpose(0, 1)
-         )
-         router_indices = torch.arange(a, device=x_aD.device).view(1, -1).expand(router_scores.size(0), -1)
-
-         router_scores = torch.sigmoid(router_scores)
-
-         routed_in_EG_D: Tensor = torch.gather(
-             x_aD,
-             dim=0,
-             index=router_indices.reshape(-1, 1).expand(-1, D),
-         )
-         routed_in_EG_D = routed_in_EG_D * router_scores.reshape(-1, 1)
-
-         out_aD = self.shared_expert(x_aD)
-         routed_out_eg_D = self.experts(routed_in_EG_D.detach())
-
-         router_indices_EG_D = router_indices.reshape(-1, 1).expand(-1, D)
-         out_aD.scatter_add_(
-             dim=0,
-             index=router_indices_EG_D,
-             src=routed_out_eg_D.view(-1, D),
-         )
-         out_aD = reduce_from_model_parallel_region(out_aD)
-         return out_aD.view(-1, slen, D)
-
-
- def divide_exact(numerator: int, denominator: int) -> int:
-     assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}"
-     return numerator // denominator
llama_stack/models/llama/llama4/preprocess.py
@@ -1,435 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # top-level folder for each specific model found within the models/ directory at
- # the top-level of this source tree.
-
- import math
- from collections import defaultdict
-
- import torch
- import torchvision.transforms as tv
- from PIL import Image, ImageFile
- from torchvision.transforms import functional as F
-
- ImageFile.LOAD_TRUNCATED_IMAGES = True
-
- IMAGE_RES = 448
-
-
- class ResizeNormalizeImageTransform:
-     def __init__(
-         self,
-         size_width=None,
-         size_height=None,
-     ) -> None:
-         self._size_width = size_width or IMAGE_RES
-         self._size_height = size_height or IMAGE_RES
-         self._mean = (0.5, 0.5, 0.5)
-         self._std = (0.5, 0.5, 0.5)
-
-         self.tv_transform = tv.Compose(
-             [
-                 tv.Resize((self._size_height, self._size_width)),
-                 tv.ToTensor(),
-                 tv.Normalize(
-                     mean=self._mean,
-                     std=self._std,
-                     inplace=True,
-                 ),
-             ]
-         )
-
-     def __call__(self, image: Image.Image) -> torch.Tensor:
-         return self.tv_transform(image)
-
-
- class VariableSizeImageTransform:
-     """
-     This class accepts images of any size and dynamically resize, pads and chunks it
-     based on the image aspect ratio and the number of image chunks we allow.
-
-     The algorithm will NOT distort the image fit a certain aspect ratio, because
-     that leads to a significant degradation in image quality.
-
-     It can be summarized in 6 steps:
-     1. Find all possible canvas combinations of max_num_chunks;
-     2. Find the best canvas to fit the image;
-     3. Resize without distortion
-     4. Pad
-     5. Normalize
-     6. Chunk
-
-     For example, if an input image is of size 300x800, patch_size of 224,
-     and max_num_chunks = 8, it will find the closest aspect ratio that
-     is allowed within 8 image chunks, with some restrictions.
-     In this case, 2:4 = 2 horizontal patches and 4 vertical patches,
-     giving a total of 8 chunks.
-
-     If resize_to_max_canvas, the image will be resized (without distortion),
-     to the largest possible resolution. In this case, 388:896, and padded to 448:896,
-     where we maintain the original aspect ratio and pad with zeros value for the rest.
-     This approach minimizes the amount of padding required for any arbitrary resolution.
-
-     However, if limit_upscaling_to_patch_size is set to True,
-     the upscaling will be limited to the patch size. In the example above,
-     the image would remain 300x800 (no upscaling), and then padded to 448:896.
-
-     The final output will therefore be of shape (8, 3, 224, 224), where 2x4
-     patches are coming from the resizing and chunking.
-     """
-
-     def __init__(self, size: int = IMAGE_RES) -> None:
-         self.size = size
-         self.to_tensor = tv.ToTensor()
-         self._mean = (0.5, 0.5, 0.5)
-         self._std = (0.5, 0.5, 0.5)
-         self.normalize = tv.Normalize(
-             mean=self._mean,
-             std=self._std,
-             inplace=True,
-         )
-         self.resample = tv.InterpolationMode.BILINEAR
-
-     @staticmethod
-     def get_factors(n: int) -> set[int]:
-         """
-         Calculate all factors of a given number, i.e. a dividor that leaves
-         no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.
-
-         Args:
-             n (int): The number to find factors for.
-
-         Returns:
-             set: A set containing all factors of the number.
-         """
-         factors_set = set()
-
-         for i in range(1, int(n**0.5) + 1):
-             if n % i == 0:
-                 factors_set.add(i)
-                 factors_set.add(n // i)
-         return factors_set
-
-     def find_supported_resolutions(self, max_num_chunks: int, patch_size: int) -> torch.Tensor:
-         """
-         Computes all of the allowed resoltuions for a fixed number of chunks
-         and patch_size. Useful for when dividing an image into chunks.
-
-         Args:
-             max_num_chunks (int): Maximum number of chunks for processing.
-             patch_size (int): Size of the side of the patch.
-
-         Returns:
-             torch.Tensor: List of possible resolutions as tuples (height, width).
-
-         Example:
-             >>> max_num_chunks = 5
-             >>> patch_size = 224
-             >>> find_supported_resolutions(max_num_chunks, patch_size)
-             tensor([(224, 896), (448, 448), (224, 224), (896, 224), (224, 672),
-             (672, 224), (224, 448), (448, 224)])
-
-             Given max_num_chunks=4, patch_size=224, it will create a dictionary:
-             {
-             0.25: [(1, 4)],
-             1.0: [(2, 2), (1, 1)],
-             4.0: [(4, 1)],
-             0.33: [(1, 3)],
-             3.0: [(3, 1)],
-             0.5: [(1, 2)],
-             2.0: [(2, 1)]
-             }
-
-             and return the resolutions multiplied by the patch_size:
-             [(1*224, 4*224), (2*224, 2*224), ..., (2*224, 1*224)]
-         """
-         asp_dict = defaultdict(list)
-         for chunk_size in range(max_num_chunks, 0, -1):
-             _factors = sorted(self.get_factors(chunk_size))
-             _asp_ratios = [(factor, chunk_size // factor) for factor in _factors]
-             for height, width in _asp_ratios:
-                 ratio_float = height / width
-                 asp_dict[ratio_float].append((height, width))
-
-         # get the resolutions multiplied by the patch_size
-         possible_resolutions = []
-         for value in asp_dict.values():
-             for height, width in value:
-                 possible_resolutions.append((height * patch_size, width * patch_size))
-
-         return possible_resolutions
-
-     @staticmethod
-     def get_max_res_without_distortion(
-         image_size: tuple[int, int],
-         target_size: tuple[int, int],
-     ) -> tuple[int, int]:
-         """
-         Determines the maximum resolution to which an image can be resized to without distorting its
-         aspect ratio, based on the target resolution.
-
-         Args:
-             image_size (Tuple[int, int]): The original resolution of the image (height, width).
-             target_resolution (Tuple[int, int]): The desired resolution to fit the image into (height, width).
-         Returns:
-             Tuple[int, int]: The optimal dimensions (height, width) to which the image should be resized.
-         Example:
-             >>> _get_max_res_without_distortion([200, 300], target_size = [450, 200])
-             (134, 200)
-             >>> _get_max_res_without_distortion([800, 600], target_size = [450, 1300])
-             (450, 338)
-         """
-
-         original_width, original_height = image_size
-         target_width, target_height = target_size
-
-         scale_w = target_width / original_width
-         scale_h = target_height / original_height
-
-         if scale_w < scale_h:
-             new_width = target_width
-             new_height = min(math.floor(original_height * scale_w), target_height)
-         else:
-             new_height = target_height
-             new_width = min(math.floor(original_width * scale_h), target_width)
-
-         return new_width, new_height
-
-     def _pad(self, image: Image.Image, target_size) -> Image.Image:
-         new_width, new_height = target_size
-         new_im = Image.new(mode="RGB", size=(new_width, new_height), color=(0, 0, 0))  # type: ignore
-         new_im.paste(image)
-         return new_im
-
-     def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor:
-         # Split image into number of required tiles (width x height)
-         num_channels, height, width = image.size()
-         image = image.view(num_channels, nch, height // nch, ncw, width // ncw)
-         # Permute dimensions to reorder the axes
-         image = image.permute(1, 3, 0, 2, 4).contiguous()
-         # Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2)
-         image = image.view(ncw * nch, num_channels, height // nch, width // ncw)
-         return image
-
-     def resize_without_distortion(
-         self,
-         image: torch.Tensor,
-         target_size: tuple[int, int],
-         max_upscaling_size: int | None,
-     ) -> torch.Tensor:
-         """
-         Used to resize an image to target_resolution, without distortion.
-
-         If target_size requires upscaling the image, the user can set max_upscaling_size to
-         limit the upscaling to a maximum size. In this case, since we rescale without distortion,
-         modifying target_size works as a boundary for the image's largest side.
-
-         Args:
-             resample (str): Resampling method used when resizing images.
-                 Supports "nearest", "nearest_exact", "bilinear", "bicubic".
-             max_upscaling_size (int): The maximum size to upscale the image to.
-                 If None, there is no limit.
-         Examples:
-         >>> target_size = (1000, 1200)
-         >>> max_upscaling_size = 600
-         >>> image_size = (400, 200)
-         >>> resize_without_distortion(image_size, target_size, max_upscaling_size)
-         (600, 300) # new_size_without_distortion
-
-         >>> target_size = (1000, 1200)
-         >>> max_upscaling_size = 600
-         >>> image_size = (2000, 200)
-         >>> resize_without_distortion(image_size, target_size, max_upscaling_size)
-         (1000, 100) # new_size_without_distortion
-
-         >>> target_size = (1000, 1200)
-         >>> max_upscaling_size = 2000
-         >>> image_size = (400, 200)
-         >>> resize_without_distortion(image_size, target_size, max_upscaling_size)
-         (1000, 500) # new_size_without_distortion
-
-         >>> target_size = (1000, 1200)
-         >>> max_upscaling_size = None
-         >>> image_size = (400, 200)
-         >>> resize_without_distortion(image_size, target_size, max_upscaling_size)
-         (1000, 500) # new_size_without_distortion
-         """
-
-         image_width, image_height = image.size
-         image_size = (image_width, image_height)
-
-         # If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size
-         if max_upscaling_size is not None:
-             new_target_width = min(max(image_width, max_upscaling_size), target_size[0])
-             new_target_height = min(max(image_height, max_upscaling_size), target_size[1])
-             target_size = (new_target_width, new_target_height)
-
-         # resize to target_size while preserving aspect ratio
-         new_size_without_distortion = self.get_max_res_without_distortion(image_size, target_size)
-
-         image = F.resize(
-             image,
-             (
-                 max(new_size_without_distortion[1], 1),
-                 max(new_size_without_distortion[0], 1),
-             ),
-             interpolation=self.resample,
-         )
-
-         return image
-
-     def get_best_fit(
-         self,
-         image_size: tuple[int, int],
-         possible_resolutions: torch.Tensor,
-         resize_to_max_canvas: bool = False,
-     ) -> tuple[int, int]:
-         """
-         Determines the best canvas possible from a list of possible resolutions to, without distortion,
-         resize an image to.
-
-         For each possible resolution, calculates the scaling factors for
-         width and height, and selects the smallest one, which is the limiting side.
-         E.g. to match the canvas you can upscale height by 2x, and width by 1.5x,
-         therefore, the maximum upscaling you can do is min(2, 1.5) = 1.5.
-
-         If upscaling is possible (any of the scaling factors is greater than 1),
-         then picks the smallest upscaling factor > 1, unless resize_to_max_canvas is True.
-
-         If upscaling is not possible, then picks the largest scaling factor <= 1, i.e.
-         reduce downscaling as much as possible.
-
-         If there are multiple resolutions with the same max scale, we pick the one with the lowest area,
-         to minimize padding. E.g., the same image can be upscaled to 224x224 and 224x448, but the latter
-         has more padding.
-
-         Args:
-             image_size (Tuple[int, int]): A tuple containing the height and width of the image.
-             possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
-                 row represents a possible resolution (height, width).
-             use_max_upscaling (bool): If True, will return the largest upscaling resolution.
-
-         Returns:
-             List[int]: The best resolution [height, width] for the given image.
-
-         Example:
-             >>> image_size = (200, 300)
-             >>> possible_resolutions = torch.tensor([[224, 672],
-             ...     [672, 224],
-             ...     [224, 448],
-             ...     [448, 224],
-             ...     [224, 224]])
-             >>> _get_smallest_upscaling_possibility(image_size, possible_resolutions)
-             [224, 448]
-
-             We have:
-                 scale_w = tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
-                 scale_h = tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
-                 scales = tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])
-             Only one of the scales > 1:
-                 upscaling_possible = tensor([1.1200, 1.1200])
-                 smallest_rescale = tensor(1.1200)
-             So we pick the resolution with the smallest smallest area:
-                 areas = tensor([150528, 100352]) # [672, 224], [224, 448]
-                 optimal_canvas = tensor([224, 448])
-         """
-
-         original_width, original_height = image_size
-
-         # get all possible resolutions heights/widths
-         target_widths, target_heights = (
-             possible_resolutions[:, 0],
-             possible_resolutions[:, 1],
-         )
-
-         # get scaling factors to resize the image without distortion
-         scale_w = target_widths / original_width
-         scale_h = target_heights / original_height
-
-         # get the min scale between width and height (limiting side -> no distortion)
-         scales = torch.where(scale_w > scale_h, scale_h, scale_w)
-
-         # filter only scales that allow upscaling
-         upscaling_options = scales[scales >= 1]
-         if len(upscaling_options) > 0:
-             if resize_to_max_canvas:
-                 selected_scale = torch.max(upscaling_options)
-             else:
-                 selected_scale = torch.min(upscaling_options)
-         else:
-             # no upscaling possible,
-             # get the minimum downscaling (max scale for scales<1)
-             downscaling_options = scales[scales < 1]
-             selected_scale = torch.max(downscaling_options)
-
-         # get all resolutions that support this scaling factor,
-         # e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
-         chosen_canvas = possible_resolutions[scales == selected_scale]
-
-         # if there are multiple resolutions,
-         # get the one with minimum area to reduce padding
-         if len(chosen_canvas) > 1:
-             areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
-             optimal_idx = torch.argmin(areas)
-             optimal_canvas = chosen_canvas[optimal_idx]
-         else:
-             optimal_canvas = chosen_canvas[0]
-
-         return tuple(optimal_canvas.tolist())
-
-     def __call__(
-         self,
-         image: Image.Image,
-         max_num_chunks: int,
-         normalize_img: bool = True,
-         resize_to_max_canvas: bool = False,
-     ) -> tuple[torch.Tensor, tuple[int, int]]:
-         """
-         Args:
-             image (PIL.Image): Image to be resized.
-             max_num_chunks (int): Maximum number of chunks to split the image into.
-             normalize_img (bool): Whether to normalize the image.
-             resize_to_max_canvas (bool): Whether to resize the image to the maximum canvas size.
-                 If True, picks the canvas the allows the largest resizing without distortion.
-                 If False, downsample as little as possible, including no resizing at all,
-                 but never upsample, unless the image is smaller than the patch size.
-         """
-         assert max_num_chunks > 0
-         assert isinstance(image, Image.Image), type(image)
-         w, h = image.size
-
-         possible_resolutions = self.find_supported_resolutions(max_num_chunks=max_num_chunks, patch_size=self.size)
-         possible_resolutions = torch.tensor(possible_resolutions)
-
-         best_resolution = self.get_best_fit(
-             image_size=(w, h),
-             possible_resolutions=possible_resolutions,
-             resize_to_max_canvas=resize_to_max_canvas,
-         )
-
-         max_upscaling_size = None if resize_to_max_canvas else self.size
-         image = self.resize_without_distortion(image, best_resolution, max_upscaling_size)
-         image = self._pad(image, best_resolution)
-
-         image = self.to_tensor(image)
-
-         if normalize_img:
-             image = self.normalize(image)
-
-         ratio_w, ratio_h = (
-             best_resolution[0] // self.size,
-             best_resolution[1] // self.size,
-         )
-
-         image = self._split(image, ratio_w, ratio_h)  # type: ignore
-
-         ar = (ratio_h, ratio_w)
-         return image, ar
@@ -1,5 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.