# 积累日常工作中用到的 PyTorch 脚本片段，以及一些强大但难以理解的函数的解释
def get_parameter_number(model):
    """Count the parameters of a model.

    Args:
        model: object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).

    Returns:
        dict with two integer entries:
            'Total': number of elements summed over all parameters.
            'Trainable': number of elements in parameters whose
                ``requires_grad`` flag is set.
    """
    total_num = 0
    trainable_num = 0
    # Single pass: every parameter counts toward the total, and only those
    # that receive gradients count toward the trainable figure.
    for param in model.parameters():
        count = param.numel()
        total_num += count
        if param.requires_grad:
            trainable_num += count
    return {'Total': total_num, 'Trainable': trainable_num}
def set_weight_decay(model, decay_rate=0.01, lr=0.0001, optimizer_class=AdamW):
    """Build an optimizer that exempts selected parameters from weight decay.

    Parameters whose name contains ``'bias'``, ``'gamma'`` or ``'beta'`` are
    placed in a group with zero weight decay; all remaining parameters are
    placed in a group that uses ``decay_rate``.

    Args:
        model: object exposing ``named_parameters()`` (e.g. a
            ``torch.nn.Module``).
        decay_rate: weight decay applied to the decayed group.
        lr: learning rate passed to the optimizer.
        optimizer_class: callable invoked as
            ``optimizer_class(params=<list of group dicts>, lr=lr)``.

    Returns:
        The optimizer instance built by ``optimizer_class``.
    """
    no_decay = ('bias', 'gamma', 'beta')
    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if any(marker in name for marker in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    # torch.optim optimizers read the per-group option 'weight_decay'; a group
    # that only carries 'weight_decay_rate' silently falls back to the
    # optimizer's default decay. 'weight_decay_rate' is kept as well so that an
    # optimizer class reading that key keeps working.
    optimizer_grouped_parameters = [
        {'params': decay_params,
         'weight_decay': decay_rate,
         'weight_decay_rate': decay_rate},
        {'params': no_decay_params,
         'weight_decay': 0.0,
         'weight_decay_rate': 0.0},
    ]
    return optimizer_class(params=optimizer_grouped_parameters, lr=lr)
class FocalLoss(torch.nn.Module):
    """Softmax (multi-class) and sigmoid (multi-label) focal loss.

    Copied from https://github.com/lonePatient/TorchBlocks

    Args:
        num_labels: number of classes; only used to size the one-hot target
            in the ``'softmax'`` mode.
        activation_type: ``'softmax'`` or ``'sigmoid'``.
        gamma: exponent of the modulating factor ``(1 - p) ** gamma``.
        alpha: weight applied to the positive term; in ``'sigmoid'`` mode the
            negative term is weighted by ``1 - alpha``.
        epsilon: small constant added inside ``log`` to avoid ``log(0)``.
    """

    def __init__(self, num_labels, activation_type='softmax',
                 gamma=2.0, alpha=0.25, epsilon=1.e-9):
        super().__init__()
        self.num_labels = num_labels  # only for multiclass
        self.gamma = gamma
        self.alpha = alpha
        self.epsilon = epsilon
        self.activation_type = activation_type

    def forward(self, input, target, reduction="None", is_logits=True):
        """Compute the focal loss.

        Args:
            input: model output. In ``'softmax'`` mode it is normalized over
                the last dimension and combined with a
                ``[N, num_labels]`` one-hot target.
            target: ground-truth labels. In ``'softmax'`` mode it is flattened
                to one class index per row; in ``'sigmoid'`` mode it is used
                element-wise against ``input`` (assumes a multi-hot tensor
                broadcastable to ``input`` -- TODO confirm with callers).
            reduction: only the exact string ``"None"`` returns the unreduced
                loss; any other value returns the mean over all elements.
            is_logits: if True, ``input`` is passed through softmax/sigmoid
                first; if False, ``input`` is used as probabilities as-is.

        Returns:
            With ``reduction="None"``: per-sample loss of shape ``[N]`` in
            ``'softmax'`` mode, or the element-wise loss in ``'sigmoid'``
            mode. Otherwise a scalar tensor (the mean).

        Raises:
            ValueError: if ``activation_type`` is neither ``'softmax'`` nor
                ``'sigmoid'``.
        """
        if self.activation_type == 'softmax':
            # Build a one-hot target matrix from the class indices.
            idx = target.view(-1, 1).long()
            one_hot_key = torch.zeros(
                idx.size(0), self.num_labels,
                dtype=torch.float32, device=idx.device)
            one_hot_key = one_hot_key.scatter_(1, idx, 1)
            probs = torch.softmax(input, dim=-1) if is_logits else input
            loss = -self.alpha * one_hot_key * \
                torch.pow((1 - probs), self.gamma) * \
                (probs + self.epsilon).log()
            # Only the true-class column is non-zero; collapse to one value
            # per sample.
            loss = loss.sum(1)
        elif self.activation_type == 'sigmoid':
            multi_hot_key = target
            probs = torch.sigmoid(input) if is_logits else input
            zero_hot_key = 1 - multi_hot_key
            # Positive labels: down-weight confident (p -> 1) predictions.
            loss = -self.alpha * multi_hot_key * \
                torch.pow((1 - probs), self.gamma) * \
                (probs + self.epsilon).log()
            # Negative labels: down-weight confident (p -> 0) predictions.
            loss += -(1 - self.alpha) * zero_hot_key * \
                torch.pow(probs, self.gamma) * \
                (1 - probs + self.epsilon).log()
        else:
            # Without this branch `loss` would be unassigned and the method
            # would fail later with an opaque UnboundLocalError.
            raise ValueError(
                "activation_type must be 'softmax' or 'sigmoid', got %r"
                % (self.activation_type,))
        if reduction == "None":
            return loss
        return loss.mean()