liger-kernel 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +3 -0
- liger_kernel/chunked_loss/cpo_loss.py +18 -8
- liger_kernel/chunked_loss/dpo_loss.py +20 -10
- liger_kernel/chunked_loss/functional.py +4 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +58 -44
- liger_kernel/chunked_loss/fused_linear_preference.py +108 -60
- liger_kernel/chunked_loss/fused_linear_rlhf.py +213 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +246 -0
- liger_kernel/chunked_loss/grpo_loss.py +160 -0
- liger_kernel/chunked_loss/jsd_loss.py +154 -0
- liger_kernel/chunked_loss/kto_loss.py +172 -0
- liger_kernel/chunked_loss/orpo_loss.py +8 -9
- liger_kernel/chunked_loss/simpo_loss.py +22 -8
- liger_kernel/env_report.py +5 -12
- liger_kernel/ops/cross_entropy.py +102 -51
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_linear_cross_entropy.py +89 -55
- liger_kernel/ops/fused_linear_jsd.py +14 -32
- liger_kernel/ops/geglu.py +6 -17
- liger_kernel/ops/group_norm.py +11 -28
- liger_kernel/ops/jsd.py +5 -9
- liger_kernel/ops/kl_div.py +8 -11
- liger_kernel/ops/layer_norm.py +23 -12
- liger_kernel/ops/qwen2vl_mrope.py +8 -25
- liger_kernel/ops/rms_norm.py +14 -32
- liger_kernel/ops/rope.py +31 -33
- liger_kernel/ops/swiglu.py +4 -8
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +3 -2
- liger_kernel/transformers/__init__.py +19 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +7 -9
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/functional.py +28 -7
- liger_kernel/transformers/fused_linear_cross_entropy.py +15 -10
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +9 -15
- liger_kernel/transformers/jsd.py +1 -3
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/model/gemma.py +18 -40
- liger_kernel/transformers/model/gemma2.py +19 -41
- liger_kernel/transformers/model/llama.py +22 -48
- liger_kernel/transformers/model/mistral.py +14 -26
- liger_kernel/transformers/model/mixtral.py +24 -54
- liger_kernel/transformers/model/mllama.py +16 -36
- liger_kernel/transformers/model/olmo2.py +124 -0
- liger_kernel/transformers/model/phi3.py +18 -40
- liger_kernel/transformers/model/qwen2.py +18 -40
- liger_kernel/transformers/model/qwen2_vl.py +36 -32
- liger_kernel/transformers/monkey_patch.py +214 -144
- liger_kernel/transformers/rms_norm.py +4 -4
- liger_kernel/transformers/rope.py +2 -2
- liger_kernel/transformers/swiglu.py +2 -8
- liger_kernel/transformers/trainer/__init__.py +1 -3
- liger_kernel/transformers/trainer/orpo_trainer.py +31 -18
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- liger_kernel/utils.py +49 -0
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/METADATA +53 -26
- liger_kernel-0.5.4.dist-info/RECORD +74 -0
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/WHEEL +1 -1
- liger_kernel-0.5.2.dist-info/RECORD +0 -65
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/LICENSE +0 -0
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/NOTICE +0 -0
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn.functional as F
|
|
3
|
+
|
|
4
|
+
from liger_kernel.chunked_loss.fused_linear_unpaired_preference import LigerFusedLinearUnpairedPreferenceBase
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
|
|
8
|
+
@staticmethod
|
|
9
|
+
def preference_loss_fn(
|
|
10
|
+
average_log_prob_chunk,
|
|
11
|
+
preference_labels_chunk,
|
|
12
|
+
full_target,
|
|
13
|
+
ref_average_log_prob_chunk=None,
|
|
14
|
+
beta=0.1,
|
|
15
|
+
kl=None,
|
|
16
|
+
):
|
|
17
|
+
"""
|
|
18
|
+
Implements the Kahneman-Tversky Optimization (KTO) loss function.
|
|
19
|
+
Paper: "KTO: Model Alignment as Prospect Theory-Guided Optimization"
|
|
20
|
+
https://arxiv.org/abs/2402.01306
|
|
21
|
+
|
|
22
|
+
KTO loss is inspired by prospect theory (https://en.wikipedia.org/wiki/Prospect_theory)
|
|
23
|
+
from behavioral economics, which models how humans make decisions under uncertainty.
|
|
24
|
+
The loss function is asymmetric, treating gains and losses differently, similar to
|
|
25
|
+
human decision-making patterns.
|
|
26
|
+
|
|
27
|
+
Formula:
|
|
28
|
+
When y is chosen:
|
|
29
|
+
L_KTO = 1 - σ(β * (log[π(x)/π₀(x)] - KL(π||π₀)_y))
|
|
30
|
+
When y is rejected:
|
|
31
|
+
L_KTO = 1 - σ(β * (KL(π||π₀)_y - log[π(x)/π₀(x)]))
|
|
32
|
+
|
|
33
|
+
Where:
|
|
34
|
+
- σ: Sigmoid function
|
|
35
|
+
- β: Temperature parameter controlling the strength of the preference signal
|
|
36
|
+
- π(x): Policy (current model)
|
|
37
|
+
- π₀(x): Reference policy (reference model)
|
|
38
|
+
- KL(π||π₀)_y: KL divergence estimated using the rejected response y
|
|
39
|
+
|
|
40
|
+
The loss encourages the model to:
|
|
41
|
+
1. Assign higher probability to chosen responses
|
|
42
|
+
2. Assign lower probability to rejected responses
|
|
43
|
+
3. Maintain reasonable distance from the reference model
|
|
44
|
+
|
|
45
|
+
Args:
|
|
46
|
+
average_log_prob_chunk: Log probabilities for the chunk (batch_size,)
|
|
47
|
+
preference_labels_chunk: Preference labels for the chunk (batch_size,)
|
|
48
|
+
full_target: Non chunked full target tensor
|
|
49
|
+
ref_average_log_prob_chunk: Reference log probs for the chunk (batch_size,)
|
|
50
|
+
beta: Weight for the KTO loss
|
|
51
|
+
kl: KL divergence between the policy model and the reference model for the chosen responses. Shape: (batch_size,)
|
|
52
|
+
Returns:
|
|
53
|
+
- loss: The KTO loss value
|
|
54
|
+
"""
|
|
55
|
+
if ref_average_log_prob_chunk is not None:
|
|
56
|
+
logratios_chunk = average_log_prob_chunk - ref_average_log_prob_chunk
|
|
57
|
+
else:
|
|
58
|
+
logratios_chunk = average_log_prob_chunk
|
|
59
|
+
|
|
60
|
+
multiplier_chunk = torch.where(preference_labels_chunk, 1, -1)
|
|
61
|
+
if kl is not None:
|
|
62
|
+
losses = 1 - F.sigmoid(beta * (logratios_chunk - kl) * multiplier_chunk)
|
|
63
|
+
else:
|
|
64
|
+
losses = 1 - F.sigmoid(beta * logratios_chunk * multiplier_chunk)
|
|
65
|
+
|
|
66
|
+
return losses.sum() / (full_target.shape[0])
|
|
67
|
+
|
|
68
|
+
@staticmethod
|
|
69
|
+
def forward(
|
|
70
|
+
ctx,
|
|
71
|
+
_input,
|
|
72
|
+
weight,
|
|
73
|
+
target,
|
|
74
|
+
preference_labels,
|
|
75
|
+
bias=None,
|
|
76
|
+
ref_input=None,
|
|
77
|
+
ref_weight=None,
|
|
78
|
+
ref_bias=None,
|
|
79
|
+
kl=None,
|
|
80
|
+
ignore_index=-100,
|
|
81
|
+
beta=0.1,
|
|
82
|
+
compiled=True,
|
|
83
|
+
use_ref_model=True,
|
|
84
|
+
):
|
|
85
|
+
return LigerFusedLinearUnpairedPreferenceBase.forward(
|
|
86
|
+
ctx=ctx,
|
|
87
|
+
_input=_input,
|
|
88
|
+
weight=weight,
|
|
89
|
+
target=target,
|
|
90
|
+
preference_labels=preference_labels,
|
|
91
|
+
bias=bias,
|
|
92
|
+
loss_fn=LigerFusedLinearKTOFunction.preference_loss_fn,
|
|
93
|
+
ignore_index=ignore_index,
|
|
94
|
+
beta=beta,
|
|
95
|
+
compiled=compiled,
|
|
96
|
+
use_ref_model=use_ref_model,
|
|
97
|
+
ref_input=ref_input,
|
|
98
|
+
ref_weight=ref_weight,
|
|
99
|
+
ref_bias=ref_bias,
|
|
100
|
+
kl=kl,
|
|
101
|
+
)
|
|
102
|
+
|
|
103
|
+
@staticmethod
|
|
104
|
+
def backward(ctx, *grad_output):
|
|
105
|
+
grads = LigerFusedLinearUnpairedPreferenceBase.backward(ctx, grad_output)[:5]
|
|
106
|
+
return (
|
|
107
|
+
*grads,
|
|
108
|
+
None,
|
|
109
|
+
None,
|
|
110
|
+
None,
|
|
111
|
+
None,
|
|
112
|
+
None,
|
|
113
|
+
None,
|
|
114
|
+
None,
|
|
115
|
+
None,
|
|
116
|
+
None,
|
|
117
|
+
None,
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class LigerFusedLinearKTOLoss(torch.nn.Module):
|
|
122
|
+
"""
|
|
123
|
+
Fused linear layer with Kahneman-Tversky Optimization (KTO) loss.
|
|
124
|
+
"""
|
|
125
|
+
|
|
126
|
+
def __init__(
|
|
127
|
+
self,
|
|
128
|
+
ignore_index: int = -100,
|
|
129
|
+
beta: float = 0.1,
|
|
130
|
+
compiled: bool = True,
|
|
131
|
+
use_ref_model: bool = False,
|
|
132
|
+
):
|
|
133
|
+
"""
|
|
134
|
+
Args:
|
|
135
|
+
ignore_index (int): Index to ignore in the loss calculation
|
|
136
|
+
beta (float): Temperature parameter for the KTO loss
|
|
137
|
+
compiled (bool): Whether to use compiled operations
|
|
138
|
+
use_ref_model (bool): Whether to use a reference model for the KTO loss.
|
|
139
|
+
"""
|
|
140
|
+
super().__init__()
|
|
141
|
+
self.ignore_index = ignore_index
|
|
142
|
+
self.beta = beta
|
|
143
|
+
self.compiled = compiled
|
|
144
|
+
self.use_ref_model = use_ref_model
|
|
145
|
+
|
|
146
|
+
def forward(
|
|
147
|
+
self,
|
|
148
|
+
_input,
|
|
149
|
+
lin_weight,
|
|
150
|
+
target,
|
|
151
|
+
bias=None,
|
|
152
|
+
preference_labels=None,
|
|
153
|
+
ref_input=None,
|
|
154
|
+
ref_weight=None,
|
|
155
|
+
ref_bias=None,
|
|
156
|
+
kl=None,
|
|
157
|
+
):
|
|
158
|
+
return LigerFusedLinearKTOFunction.apply(
|
|
159
|
+
_input,
|
|
160
|
+
lin_weight,
|
|
161
|
+
target,
|
|
162
|
+
preference_labels,
|
|
163
|
+
bias,
|
|
164
|
+
ref_input,
|
|
165
|
+
ref_weight,
|
|
166
|
+
ref_bias,
|
|
167
|
+
kl,
|
|
168
|
+
self.ignore_index,
|
|
169
|
+
self.beta,
|
|
170
|
+
self.compiled,
|
|
171
|
+
self.use_ref_model,
|
|
172
|
+
)
|
|
@@ -1,13 +1,10 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import torch.nn.functional as F
|
|
3
3
|
|
|
4
|
-
from liger_kernel.chunked_loss.fused_linear_preference import (
|
|
5
|
-
LigerFusedLinearPreferenceBase,
|
|
6
|
-
)
|
|
4
|
+
from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
|
|
7
5
|
|
|
8
6
|
|
|
9
7
|
class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
|
|
10
|
-
|
|
11
8
|
@staticmethod
|
|
12
9
|
def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
|
|
13
10
|
"""
|
|
@@ -32,11 +29,10 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
|
|
|
32
29
|
beta (float): Weight for the odds ratio loss.
|
|
33
30
|
"""
|
|
34
31
|
log_odds = (chosen_logps - rejected_logps) - (
|
|
35
|
-
torch.log1p(-torch.exp(chosen_logps))
|
|
36
|
-
- torch.log1p(-torch.exp(rejected_logps))
|
|
32
|
+
torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
|
|
37
33
|
)
|
|
38
34
|
ratio = F.logsigmoid(log_odds)
|
|
39
|
-
loss = beta * ratio.sum() / (full_target.shape[0] // 2)
|
|
35
|
+
loss = -beta * ratio.sum() / (full_target.shape[0] // 2)
|
|
40
36
|
|
|
41
37
|
chosen_rewards = beta * chosen_logps
|
|
42
38
|
rejected_rewards = beta * rejected_logps
|
|
@@ -56,6 +52,7 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
|
|
|
56
52
|
ignore_index=-100,
|
|
57
53
|
beta=0.1,
|
|
58
54
|
compute_nll_loss=True,
|
|
55
|
+
nll_target=None,
|
|
59
56
|
compiled=True,
|
|
60
57
|
):
|
|
61
58
|
return LigerFusedLinearPreferenceBase.forward(
|
|
@@ -68,13 +65,14 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
|
|
|
68
65
|
ignore_index=ignore_index,
|
|
69
66
|
beta=beta,
|
|
70
67
|
compute_nll_loss=compute_nll_loss,
|
|
68
|
+
nll_target=nll_target,
|
|
71
69
|
compiled=compiled,
|
|
72
70
|
)
|
|
73
71
|
|
|
74
72
|
@staticmethod
|
|
75
73
|
def backward(ctx, *grad_output):
|
|
76
74
|
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
|
|
77
|
-
return *grads, None, None, None, None
|
|
75
|
+
return *grads, None, None, None, None, None
|
|
78
76
|
|
|
79
77
|
|
|
80
78
|
class LigerFusedLinearORPOLoss(torch.nn.Module):
|
|
@@ -100,7 +98,7 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
|
|
|
100
98
|
self.compute_nll_loss = compute_nll_loss
|
|
101
99
|
self.compiled = compiled
|
|
102
100
|
|
|
103
|
-
def forward(self, lin_weight, _input, target, bias=None):
|
|
101
|
+
def forward(self, lin_weight, _input, target, bias=None, nll_target=None):
|
|
104
102
|
return LigerFusedLinearORPOFunction.apply(
|
|
105
103
|
_input,
|
|
106
104
|
lin_weight,
|
|
@@ -109,5 +107,6 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
|
|
|
109
107
|
self.ignore_index,
|
|
110
108
|
self.beta,
|
|
111
109
|
self.compute_nll_loss,
|
|
110
|
+
nll_target,
|
|
112
111
|
self.compiled,
|
|
113
112
|
)
|
|
@@ -1,16 +1,18 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import torch.nn.functional as F
|
|
3
3
|
|
|
4
|
-
from liger_kernel.chunked_loss.fused_linear_preference import (
|
|
5
|
-
LigerFusedLinearPreferenceBase,
|
|
6
|
-
)
|
|
4
|
+
from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
|
|
7
5
|
|
|
8
6
|
|
|
9
7
|
class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
|
|
10
|
-
|
|
11
8
|
@staticmethod
|
|
12
9
|
def preference_loss_fn(
|
|
13
|
-
chosen_logps,
|
|
10
|
+
chosen_logps,
|
|
11
|
+
rejected_logps,
|
|
12
|
+
full_target,
|
|
13
|
+
beta=0.1,
|
|
14
|
+
gamma=0.5,
|
|
15
|
+
label_smoothing=0.0,
|
|
14
16
|
):
|
|
15
17
|
"""
|
|
16
18
|
Paper: https://arxiv.org/pdf/2405.14734
|
|
@@ -33,10 +35,17 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
|
|
|
33
35
|
full_target: Non chunked full target tensor
|
|
34
36
|
beta (float): beta weight
|
|
35
37
|
gamma (float): gamma margin term
|
|
38
|
+
label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
|
|
36
39
|
"""
|
|
37
40
|
logits = beta * (chosen_logps - rejected_logps) - gamma
|
|
38
|
-
loss = F.logsigmoid(logits).sum() / (
|
|
39
|
-
|
|
41
|
+
loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
|
|
42
|
+
full_target.shape[0] // 2
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
chosen_rewards = beta * chosen_logps
|
|
46
|
+
rejected_rewards = beta * rejected_logps
|
|
47
|
+
|
|
48
|
+
return loss, chosen_rewards, rejected_rewards
|
|
40
49
|
|
|
41
50
|
@staticmethod
|
|
42
51
|
def forward(
|
|
@@ -48,6 +57,7 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
|
|
|
48
57
|
ignore_index=-100,
|
|
49
58
|
beta=0.1,
|
|
50
59
|
alpha=1.0,
|
|
60
|
+
label_smoothing=0.0,
|
|
51
61
|
compute_nll_loss=False,
|
|
52
62
|
compiled=True,
|
|
53
63
|
gamma=0.5,
|
|
@@ -63,6 +73,7 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
|
|
|
63
73
|
ignore_index=ignore_index,
|
|
64
74
|
alpha=alpha,
|
|
65
75
|
beta=beta,
|
|
76
|
+
label_smoothing=label_smoothing,
|
|
66
77
|
compiled=compiled,
|
|
67
78
|
gamma=gamma,
|
|
68
79
|
)
|
|
@@ -70,7 +81,7 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
|
|
|
70
81
|
@staticmethod
|
|
71
82
|
def backward(ctx, *grad_output):
|
|
72
83
|
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
|
|
73
|
-
return *grads, None, None, None, None, None, None
|
|
84
|
+
return *grads, None, None, None, None, None, None, None
|
|
74
85
|
|
|
75
86
|
|
|
76
87
|
class LigerFusedLinearSimPOLoss(torch.nn.Module):
|
|
@@ -83,6 +94,7 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
|
|
|
83
94
|
ignore_index: int = -100,
|
|
84
95
|
beta: float = 0.1,
|
|
85
96
|
alpha: float = 1.0,
|
|
97
|
+
label_smoothing: float = 0.0,
|
|
86
98
|
compute_nll_loss: bool = True,
|
|
87
99
|
compiled: bool = True,
|
|
88
100
|
gamma: float = 0.5,
|
|
@@ -96,6 +108,7 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
|
|
|
96
108
|
self.ignore_index = ignore_index
|
|
97
109
|
self.beta = beta
|
|
98
110
|
self.alpha = alpha
|
|
111
|
+
self.label_smoothing = label_smoothing
|
|
99
112
|
self.compute_nll_loss = compute_nll_loss
|
|
100
113
|
self.compiled = compiled
|
|
101
114
|
self.gamma = gamma
|
|
@@ -109,6 +122,7 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
|
|
|
109
122
|
self.ignore_index,
|
|
110
123
|
self.beta,
|
|
111
124
|
self.alpha,
|
|
125
|
+
self.label_smoothing,
|
|
112
126
|
self.compute_nll_loss,
|
|
113
127
|
self.compiled,
|
|
114
128
|
self.gamma,
|
liger_kernel/env_report.py
CHANGED
|
@@ -1,12 +1,13 @@
|
|
|
1
1
|
import platform
|
|
2
2
|
import sys
|
|
3
|
+
|
|
3
4
|
from importlib.metadata import version
|
|
4
5
|
|
|
5
6
|
|
|
6
7
|
def print_env_report():
|
|
7
8
|
"""
|
|
8
9
|
|
|
9
|
-
Prints a report of the environment.
|
|
10
|
+
Prints a report of the environment. Useful for debugging and reproducibility.
|
|
10
11
|
Usage:
|
|
11
12
|
```
|
|
12
13
|
python -m liger_kernel.env_report
|
|
@@ -27,15 +28,9 @@ def print_env_report():
|
|
|
27
28
|
import torch
|
|
28
29
|
|
|
29
30
|
print(f"PyTorch version: {torch.__version__}")
|
|
30
|
-
cuda_version = (
|
|
31
|
-
torch.version.cuda if torch.cuda.is_available() else "Not available"
|
|
32
|
-
)
|
|
31
|
+
cuda_version = torch.version.cuda if torch.cuda.is_available() else "Not available"
|
|
33
32
|
print(f"CUDA version: {cuda_version}")
|
|
34
|
-
hip_version = (
|
|
35
|
-
torch.version.hip
|
|
36
|
-
if torch.cuda.is_available() and torch.version.hip
|
|
37
|
-
else "Not available"
|
|
38
|
-
)
|
|
33
|
+
hip_version = torch.version.hip if torch.cuda.is_available() and torch.version.hip else "Not available"
|
|
39
34
|
print(f"HIP(ROCm) version: {hip_version}")
|
|
40
35
|
|
|
41
36
|
except ImportError:
|
|
@@ -58,9 +53,7 @@ def print_env_report():
|
|
|
58
53
|
print("Transformers: Not installed")
|
|
59
54
|
|
|
60
55
|
try:
|
|
61
|
-
xpu_version = (
|
|
62
|
-
torch.version.xpu if torch.xpu.is_available() else "XPU Not Available"
|
|
63
|
-
)
|
|
56
|
+
xpu_version = torch.version.xpu if torch.xpu.is_available() else "XPU Not Available"
|
|
64
57
|
print(f"XPU version: {xpu_version}")
|
|
65
58
|
except ImportError:
|
|
66
59
|
print("XPU version: Unable to query")
|