原创 PyTorch入门学习

2024-3-6 00:16 389 2 2 分类: 机器人/ AI
PyTorch入门学习

两个重要函数

package (pytorch) 有许多工具箱,探索这个工具箱,有两个工具: dir():打开工具,看见工具箱某个分区的结构 help():如何使用这个工具

IN[1]: import torch
IN[2]: torch.cuda.is_available()
Out[3]: True
IN[4]: dir(torch)
IN[5]: dir(torch.cuda)
IN[6]: dir(torch.cuda.is_available)
#输出的有双下划线表示不能被更改,是一个具体的函数了

#选择相应的环境
conda activate torch

jupyter notebook
pytorch如何加载数据?

Dataset:可以用于提取数据,和它的真实的label,就能得到某一个具体的数据。 Dataloader:为网络提供不同的数据形式。 如何获取每一个数据及其label。 统计总共有多少个数据。

组织形式:训练集,测试集,label信息;也可以将label直接放置在图片的名称上。

from torch.utils.data import Dataset
help(Dataset)
Dataset?? #表达更清晰

代码实战

from torch.utils.data import Dataset
import PIL import Image
import os #获取所有图片的地址
#import cv2
# conda install opencv-python; pip install opencv-python

class MyData(Dataset):
def __init__(self,root_dir,label_dir):#提供全局变量
       self.root_dir = root_dir #创建的时候进行赋值,就可以成为全局变量
       self.label_dir = label_dir
       self.path = os.path.join(self.root_dir,self.label_dir)
       self.img_path = os.listdir(self.path)
        #函数中的变量无法传递到其他的函数中去,self实际上就是指向类自身首地址的指针,点成员实际就是通过偏移找到相应成员

def __getitem__(self,idx): #使用idx来获取图片,首先需要获取文件夹,然后才能得到文件夹下面所有的图片
       img_name = self.img_path[idx]
       img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)
       img = Image.open(img_item_path)
       label = self.label_dir
       return img,label
 
def __len__(self):
       return len(self.img_path)

   
root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir,ants_label_dir)
bees_dataset = MyData(root_dir,bees_label_dir)

train_dataset = ants_dataset + bees_dataset

input是照片蚂蚁,label是ants

In: from PIL import Image
In: img_path="copy relative path" #注意路径在windows里面\需要加转义字符\\
# /不需要转义,\需要转义为\\
In: img = Image.open(img_path)
In: img.size
In: img.show()
In: dir_path = "dataset/train/ants"
In: img_path_list = os.listdir(dir_path)

In: import os
In: root_dir = "relative address"
In: label_dir = "ants"
In: path = os.path.join(root_dir,label_dir)  #拼接地址

In: img_path = os.listdir(path)

In: ants_dataset[0]
In: img, label = ants_dataset[0]
In: img.show()
In: img, label = ants_dataset[1]
In: img.show()

In: len(ants_dataset)
In: len(bees_dataset)
In: len(tarin_dataset)

In: img,label = train_dataset[124]

作者: 拾肆, 来源:面包板社区

链接: https://mbb.eet-china.com/blog/uid-me-4074534.html

版权声明:本文为博主原创,未经本人允许,禁止转载!

PARTNER CONTENT

文章评论0条评论)

登录后参与讨论
EE直播间
更多
我要评论
0
2
关闭 站长推荐上一条 /3 下一条