pytorch自定义dataset pytorch自定义优化器

   2023-02-09 学习力420
核心提示:参考一个例子import torchfrom torch.utils import dataclass MyDataset(data.Dataset):def __init__(self):super(MyDataset, self).__init__()self.data = torch.randn(8,2)def __getitem__(self, index):return self.data[index], indexdef __len__(self):r

参考

一个例子

import torch
from torch.utils import data

class MyDataset(data.Dataset):
    def __init__(self):
        super(MyDataset, self).__init__()
        self.data = torch.randn(8,2)
    
    def __getitem__(self, index):
        return self.data[index], index
    
    def __len__(self):
        return self.data.size()[0]

data_set = MyDataset()
print(data_set.data)

输出
tensor([[-1.3907, -0.0916],
[-0.4626, -1.3323],
[ 1.4242, -2.1718],
[ 1.5850, 0.3320],
[-1.0804, 0.3884],
[ 0.6567, -0.1234],
[ 1.6721, -0.7327],
[-1.9595, -0.3512]])

data_loader = data.DataLoader(data_set, 
                              batch_size=4,
                              shuffle=False)
print(len(data_set))
for i, (number, labels) in enumerate(data_loader):
    print(number)

输出
8
tensor([[-1.3907, -0.0916],
[-0.4626, -1.3323],
[ 1.4242, -2.1718],
[ 1.5850, 0.3320]])
tensor([0, 1, 2, 3])
tensor([[-1.0804, 0.3884],
[ 0.6567, -0.1234],
[ 1.6721, -0.7327],
[-1.9595, -0.3512]])
tensor([4, 5, 6, 7])

 
反对 0举报 0
 

免责声明:本文仅代表作者个人观点,与乐学笔记(本网)无关。其原创性以及文中陈述文字和内容未经本站证实,对本文以及其中全部或者部分内容、文字的真实性、完整性、及时性本站不作任何保证或承诺,请读者仅作参考,并请自行核实相关内容。
    本网站有部分内容均转载自其它媒体,转载目的在于传递更多信息,并不代表本网赞同其观点和对其真实性负责,若因作品内容、知识产权、版权和其他问题,请及时提供相关证明等材料并与我们留言联系,本网站将在规定时间内给予删除等相关处理.

  • 基于pytorch框架的图像分类实践(CIFAR-10数据集)
    基于pytorch框架的图像分类实践(CIFAR-10数据集
    在学习pytorch的过程中我找到了关于图像分类的很浅显的一个教程上一次做的是pytorch的手写数字图片识别是灰度图片,这次是彩色图片的分类,觉得对于像我这样的刚刚开始入门pytorch的小白来说很有意义,今天写篇关于这个图像分类的博客.收获的知识1.torchvison
    03-08
  • 今天来捋一捋pytorch官方Faster R-CNN代码
    今天来捋一捋pytorch官方Faster R-CNN代码
    AI编辑:我是小将本文作者:白裳https://zhuanlan.zhihu.com/p/145842317本文已由原作者授权 目前 pytorch 已经在 torchvision 模块集成了 FasterRCNN 和 MaskRCNN 代码。考虑到帮助各位小伙伴理解模型细节问题,本文分析一下 FasterRCNN 代码,帮助新手理解
    03-08
  • 从零搭建Pytorch模型教程(三)搭建Transformer网络
    从零搭建Pytorch模型教程(三)搭建Transformer
    ​前言 本文介绍了Transformer的基本流程,分块的两种实现方式,Position Emebdding的几种实现方式,Encoder的实现方式,最后分类的两种方式,以及最重要的数据格式的介绍。 本文来自公众号CV技术指南的技术总结系列欢迎关注公众号CV技术指南,专注于计算机
    03-08
  • 几种网络LeNet、VGG Net、ResNet原理及PyTorch实现
    几种网络LeNet、VGG Net、ResNet原理及PyTorch
    LeNet比较经典,就从LeNet开始,其PyTorch实现比较简单,通过LeNet为基础引出下面的VGG-Net和ResNet。LeNetLeNet比较经典的一张图如下图LeNet-5共有7层,不包含输入,每层都包含可训练参数;每个层有多个Feature Map,每个FeatureMap通过一种卷积滤波器提取输
    03-08
  • Focal Loss 的Pytorch 实现以及实验
    Focal Loss 的Pytorch 实现以及实验
     Focal loss 是 文章 Focal Loss for Dense Object Detection 中提出对简单样本的进行decay的一种损失函数。是对标准的Cross Entropy Loss 的一种改进。 F L对于简单样本(p比较大)回应较小的loss。如论文中的图1, 在p=0.6时, 标准的CE然后又较大的loss
    03-08
  • Pytorch-基础入门之ANN pytorch零基础入门
    在这部分中来介绍下ANN的Pytorch,这里的ANN具有三个隐含层。这一块的话与上一篇逻辑斯蒂回归使用的是相同的数据集MNIST。第一部分:构造模型# Import Librariesimport torchimport torch.nn as nnfrom torch.autograd import Variable# Create ANN Modelclas
    03-08
  • 分享一个PyTorch医学图像分割开源库 python医学图像处理dicom
    分享一个PyTorch医学图像分割开源库 python医学
    昨天点击上方↑↑↑“OpenCV学堂”关注我来源:公众号 我爱计算机视觉授权  分享一位52CV粉丝Ellis开发的基于PyTorch的专注于医学图像分割的开源库,其支持模型丰富,方便易用。其可算为torchio的一个实例,作者将其综合起来,包含众多经典算法,实用性比
    03-08
  • 搞懂Transformer结构,看这篇PyTorch实现就够了
    搞懂Transformer结构,看这篇PyTorch实现就够了
    搞懂Transformer结构,看这篇PyTorch实现就够了昨天下面分享一篇实验室翻译的来自哈佛大学一篇关于Transformer的详细博文。"Attention is All You Need"[1] 一文中提出的Transformer网络结构最近引起了很多人的关注。Transformer不仅能够明显地提升翻译质量,
    03-08
  • 行人重识别(ReID) ——基于MGN-pytorch进行可视化展示
    行人重识别(ReID) ——基于MGN-pytorch进行可视
    https://github.com/seathiefwang/MGN-pytorch下载Market1501数据集:http://www.liangzheng.org/Project/project_reid.html模型训练,修改demo.sh,将 --datadir修改已下载的Market1501数据集地址,将修改CUDA_VISIBLE_DEVICES=2,3自己的GPU设备ID,将修改--
    03-08
  • Pytorch:通过pytorch实现逻辑回归
    Pytorch:通过pytorch实现逻辑回归
    logistic regression逻辑回归是线性的二分类模型(与线性回归的区别:线性回归是回归问题,而逻辑回归是线性回归+激活函数sigmoid=分类问题)模型表达式:f(x)称为sigmoid函数,也称为logistic函数,能将所有值映射到[0,1]区间,恰好符合概率分布,如下图所示
    03-08
点击排行