PyTorch 有用的脚本片段和函数的积累


积累日常工作中用到的 PyTorch 脚本片段,以及对一些强大但难以理解的函数的解释。

查看模型参数量

def get_parameter_number(model):
    """Count the parameters of a model.

    Args:
        model: Object exposing ``parameters()``, whose items provide
            ``numel()`` and a ``requires_grad`` flag.

    Returns:
        dict: ``{'Total': <all elements>, 'Trainable': <elements of
        parameters with requires_grad set>}``.
    """
    total_num = 0
    trainable_num = 0
    # One pass over the parameters; each element count is reused for both tallies.
    for param in model.parameters():
        count = param.numel()
        total_num += count
        if param.requires_grad:
            trainable_num += count
    return {'Total': total_num, 'Trainable': trainable_num}

设置weight decay

def set_weight_decay(model, decay_rate=0.01, lr=0.0001, optimizer_class=AdamW):
    """Build an optimizer with weight decay disabled for bias/normalization parameters.

    Parameters are split by name into two groups: those whose name contains
    one of the ``no_decay`` fragments get a decay of 0.0, all others get
    ``decay_rate``.

    Args:
        model: Object exposing ``named_parameters()`` yielding
            ``(name, parameter)`` pairs.
        decay_rate: Weight decay applied to the first (decayed) group.
        lr: Learning rate passed to the optimizer constructor.
        optimizer_class: Optimizer constructor called as
            ``optimizer_class(params=<groups>, lr=lr)``.

    Returns:
        The constructed optimizer; group 0 is the decayed group, group 1 the
        non-decayed group.
    """
    # 'gamma'/'beta' are kept for checkpoints using the older naming scheme;
    # 'LayerNorm.weight' covers the same parameters under PyTorch-style names
    # (assumes the normalization layers are attributes named `LayerNorm`).
    no_decay = ('bias', 'gamma', 'beta', 'LayerNorm.weight')

    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if any(fragment in name for fragment in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    def _group(params, rate):
        # torch.optim optimizers read the per-group option 'weight_decay';
        # the original 'weight_decay_rate' key alone is silently ignored by
        # them, so every group fell back to the optimizer default. The old
        # key is still emitted for legacy optimizers that look it up.
        return {'params': params, 'weight_decay': rate, 'weight_decay_rate': rate}

    optimizer_grouped_parameters = [
        _group(decay_params, decay_rate),
        _group(no_decay_params, 0.0),
    ]
    optimizer = optimizer_class(params=optimizer_grouped_parameters, lr=lr)
    return optimizer

Focal Loss

class FocalLoss(torch.nn.Module):
    """
    Softmax and sigmoid focal loss.
    copy from https://github.com/lonePatient/TorchBlocks

    With ``activation_type='softmax'`` the target holds one class index per
    sample (single-label); with ``activation_type='sigmoid'`` the target is a
    multi-hot tensor with the same shape as the input (multi-label).
    """

    def __init__(self, num_labels, activation_type='softmax',
                 gamma=2.0, alpha=0.25, epsilon=1.e-9):
        """
        Args:
            num_labels: number of classes; only used to size the one-hot
                target in the softmax branch.
            activation_type: 'softmax' or 'sigmoid'.
            gamma: focusing exponent applied to ``(1 - p)``.
            alpha: weight of the positive term; the sigmoid branch weights
                the negative term by ``1 - alpha``.
            epsilon: constant added inside ``log`` to avoid ``log(0)``.
        """
        super(FocalLoss, self).__init__()
        self.num_labels = num_labels  # only for multiclass
        self.gamma = gamma
        self.alpha = alpha
        self.epsilon = epsilon
        self.activation_type = activation_type

    def forward(self, input, target, reduction="None", is_logits=True):
        """
        Args:
            input: model's output, shape of [batch_size, num_cls]; raw logits
                when ``is_logits`` is True, otherwise already-normalized
                probabilities.
            target: ground truth labels. Softmax: class indices, flattened to
                [batch_size]. Sigmoid: multi-hot tensor shaped like ``input``.
            reduction: "None"/"none"/None returns the unreduced loss, "mean"
                its mean, "sum" its sum (matched case-insensitively).
            is_logits: whether the activation still has to be applied.
        Returns:
            Unreduced: shape of [batch_size] for softmax, the shape of
            ``input`` for sigmoid. Reduced: a scalar tensor.
        Raises:
            ValueError: on an unknown ``reduction`` or ``activation_type``.
        """
        # The original only recognised the literal "None"; any other value,
        # including the torch-style "none" and "sum", silently produced the mean.
        mode = "none" if reduction is None else str(reduction).lower()
        if mode not in ("none", "mean", "sum"):
            raise ValueError(
                "reduction must be 'none', 'mean' or 'sum', got %r" % (reduction,))

        if self.activation_type == 'softmax':
            idx = target.view(-1, 1).long()
            one_hot_key = torch.zeros(
                idx.size(0), self.num_labels, dtype=torch.float32, device=idx.device)
            one_hot_key = one_hot_key.scatter_(1, idx, 1)
            probs = torch.softmax(input, dim=-1) if is_logits else input
            # The one-hot mask zeroes every class except the target one, so
            # the sum over dim 1 picks out the focal term of the true class.
            loss = -self.alpha * one_hot_key * \
                torch.pow((1 - probs), self.gamma) * \
                (probs + self.epsilon).log()
            loss = loss.sum(1)
        elif self.activation_type == 'sigmoid':
            probs = torch.sigmoid(input) if is_logits else input
            # Bool targets do not support `1 - target`; cast non-float
            # targets to the probability dtype first.
            multi_hot_key = target if target.is_floating_point() else target.to(probs.dtype)
            zero_hot_key = 1 - multi_hot_key
            positive_term = -self.alpha * multi_hot_key * \
                torch.pow((1 - probs), self.gamma) * \
                (probs + self.epsilon).log()
            negative_term = -(1 - self.alpha) * zero_hot_key * \
                torch.pow(probs, self.gamma) * \
                (1 - probs + self.epsilon).log()
            loss = positive_term + negative_term
        else:
            # Previously this fell through to an UnboundLocalError on `loss`.
            raise ValueError(
                "activation_type must be 'softmax' or 'sigmoid', got %r"
                % (self.activation_type,))

        if mode == "none":
            return loss
        if mode == "sum":
            return loss.sum()
        return loss.mean()

原创文章,转载请注明出处,否则拒绝转载!
本文链接:抬头看浏览器地址栏

上篇: 文本tokenize方法总结
下篇: BLEU的计算