Focal Loss 的Pytorch 实现以及实验

   2023-03-08 学习力777
核心提示: Focal loss 是 文章 Focal Loss for Dense Object Detection 中提出对简单样本的进行decay的一种损失函数。是对标准的Cross Entropy Loss 的一种改进。 F L对于简单样本(p比较大)回应较小的loss。如论文中的图1, 在p=0.6时, 标准的CE然后又较大的loss

Focal loss 是 文章 Focal Loss for Dense Object Detection 中提出对简单样本的进行decay的一种损失函数。是对标准的Cross Entropy Loss 的一种改进。 F L对于简单样本(p比较大)回应较小的loss。

如论文中的图1, 在p=0.6时, 标准的CE然后又较大的loss, 但是对于FL就有相对较小的loss回应。这样就是对简单样本的一种decay。其中alpha 是对每个类别在训练数据中的频率有关, 但是下面的实现我们是基于alpha=1进行实验的。

Focal Loss 的Pytorch 实现以及实验

标准的Cross Entropy 为:

Focal Loss 的Pytorch 实现以及实验

Focal Loss 为:

Focal Loss 的Pytorch 实现以及实验

Focal Loss 的Pytorch 实现以及实验

其中 Focal Loss 的Pytorch 实现以及实验

以上公式为下面实现代码的基础。

 

采用基于pytorch 的yolo2 在VOC的上的实验结果如下:

 

Focal Loss 的Pytorch 实现以及实验

在单纯的替换了CrossEntropyLoss之后就有1个点左右的提升。效果还是比较显著的。本实验中采用的是darknet19, 要是采用更大的网络就可能会有更好的性能提升。这个实验结果已经能很好的说明的Focal Loss 的对于检测的价值了。

 

一点没做的但是可能会提升性能:

1. 采用soft - gamma: 在训练的过程中阶段性的增大gamma 可能会有更好的性能提升

 

 

本文实验中采用的Focal Loss 代码如下。

关于Focal Loss 的数学推倒在文章:Focal Loss 的前向与后向公式推导

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class FocalLoss(nn.Module):
    r"""
        This criterion is a implemenation of Focal Loss, which is proposed in 
        Focal Loss for Dense Object Detection.

            Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])

        The losses are averaged across observations for each minibatch.

        Args:
            alpha(1D Tensor, Variable) : the scalar factor for this criterion
            gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5), 
                                   putting more focus on hard, misclassified examples
            size_average(bool): By default, the losses are averaged over observations for each minibatch.
                                However, if the field size_average is set to False, the losses are
                                instead summed for each minibatch.


    """
    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            self.alpha = Variable(torch.ones(class_num, 1))
        else:
            if isinstance(alpha, Variable):
                self.alpha = alpha
            else:
                self.alpha = Variable(alpha)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        N = inputs.size(0)
        C = inputs.size(1)
        P = F.softmax(inputs)

        class_mask = inputs.data.new(N, C).fill_(0)
        class_mask = Variable(class_mask)
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids.data, 1.)
        #print(class_mask)


        if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()
        alpha = self.alpha[ids.data.view(-1)]

        probs = (P*class_mask).sum(1).view(-1,1)

        log_p = probs.log()
        #print('probs size= {}'.format(probs.size()))
        #print(probs)

        batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p 
        #print('-----bacth_loss------')
        #print(batch_loss)


        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss

 

 
反对 0举报 0
 

免责声明:本文仅代表作者个人观点,与乐学笔记(本网)无关。其原创性以及文中陈述文字和内容未经本站证实,对本文以及其中全部或者部分内容、文字的真实性、完整性、及时性本站不作任何保证或承诺,请读者仅作参考,并请自行核实相关内容。
    本网站有部分内容均转载自其它媒体,转载目的在于传递更多信息,并不代表本网赞同其观点和对其真实性负责,若因作品内容、知识产权、版权和其他问题,请及时提供相关证明等材料并与我们留言联系,本网站将在规定时间内给予删除等相关处理.

  • 基于pytorch框架的图像分类实践(CIFAR-10数据集)
    基于pytorch框架的图像分类实践(CIFAR-10数据集
    在学习pytorch的过程中我找到了关于图像分类的很浅显的一个教程上一次做的是pytorch的手写数字图片识别是灰度图片,这次是彩色图片的分类,觉得对于像我这样的刚刚开始入门pytorch的小白来说很有意义,今天写篇关于这个图像分类的博客.收获的知识1.torchvison
    03-08
  • 今天来捋一捋pytorch官方Faster R-CNN代码
    今天来捋一捋pytorch官方Faster R-CNN代码
    AI编辑:我是小将本文作者:白裳https://zhuanlan.zhihu.com/p/145842317本文已由原作者授权 目前 pytorch 已经在 torchvision 模块集成了 FasterRCNN 和 MaskRCNN 代码。考虑到帮助各位小伙伴理解模型细节问题,本文分析一下 FasterRCNN 代码,帮助新手理解
    03-08
  • 从零搭建Pytorch模型教程(三)搭建Transformer网络
    从零搭建Pytorch模型教程(三)搭建Transformer
    ​前言 本文介绍了Transformer的基本流程,分块的两种实现方式,Position Emebdding的几种实现方式,Encoder的实现方式,最后分类的两种方式,以及最重要的数据格式的介绍。 本文来自公众号CV技术指南的技术总结系列欢迎关注公众号CV技术指南,专注于计算机
    03-08
  • 几种网络LeNet、VGG Net、ResNet原理及PyTorch实现
    几种网络LeNet、VGG Net、ResNet原理及PyTorch
    LeNet比较经典,就从LeNet开始,其PyTorch实现比较简单,通过LeNet为基础引出下面的VGG-Net和ResNet。LeNetLeNet比较经典的一张图如下图LeNet-5共有7层,不包含输入,每层都包含可训练参数;每个层有多个Feature Map,每个FeatureMap通过一种卷积滤波器提取输
    03-08
  • Pytorch-基础入门之ANN pytorch零基础入门
    在这部分中来介绍下ANN的Pytorch,这里的ANN具有三个隐含层。这一块的话与上一篇逻辑斯蒂回归使用的是相同的数据集MNIST。第一部分:构造模型# Import Librariesimport torchimport torch.nn as nnfrom torch.autograd import Variable# Create ANN Modelclas
    03-08
  • 分享一个PyTorch医学图像分割开源库 python医学图像处理dicom
    分享一个PyTorch医学图像分割开源库 python医学
    昨天点击上方↑↑↑“OpenCV学堂”关注我来源:公众号 我爱计算机视觉授权  分享一位52CV粉丝Ellis开发的基于PyTorch的专注于医学图像分割的开源库,其支持模型丰富,方便易用。其可算为torchio的一个实例,作者将其综合起来,包含众多经典算法,实用性比
    03-08
  • 搞懂Transformer结构,看这篇PyTorch实现就够了
    搞懂Transformer结构,看这篇PyTorch实现就够了
    搞懂Transformer结构,看这篇PyTorch实现就够了昨天下面分享一篇实验室翻译的来自哈佛大学一篇关于Transformer的详细博文。"Attention is All You Need"[1] 一文中提出的Transformer网络结构最近引起了很多人的关注。Transformer不仅能够明显地提升翻译质量,
    03-08
  • 行人重识别(ReID) ——基于MGN-pytorch进行可视化展示
    行人重识别(ReID) ——基于MGN-pytorch进行可视
    https://github.com/seathiefwang/MGN-pytorch下载Market1501数据集:http://www.liangzheng.org/Project/project_reid.html模型训练,修改demo.sh,将 --datadir修改已下载的Market1501数据集地址,将修改CUDA_VISIBLE_DEVICES=2,3自己的GPU设备ID,将修改--
    03-08
  • Pytorch:通过pytorch实现逻辑回归
    Pytorch:通过pytorch实现逻辑回归
    logistic regression逻辑回归是线性的二分类模型(与线性回归的区别:线性回归是回归问题,而逻辑回归是线性回归+激活函数sigmoid=分类问题)模型表达式:f(x)称为sigmoid函数,也称为logistic函数,能将所有值映射到[0,1]区间,恰好符合概率分布,如下图所示
    03-08
  • 在 Windows 上为 Pytorch 和 Pytorch Geometric 构建 GPU 环境
    在 Windows 上为 Pytorch 和 Pytorch Geometric
    介绍这是我的第一篇文章。在研究机器学习时,我在使用 Pytorch 和 Pytorch Geometric 构建 GPU 环境时遇到了很多麻烦,所以我想留下我构建环境所做的工作。我希望这可以帮助任何处于类似情况的人。环境操作系统语GPUWindows 11 家庭 64 位蟒蛇 3.9.13RTX3060
    03-08
点击排行