原创
PyTorch入门学习
2024-3-6 00:16
389
2
2
分类:
机器人/ AI
PyTorch入门学习两个重要函数
package (pytorch) 有许多工具箱,探索这个工具箱,有两个工具:
dir():打开工具,看见工具箱某个分区的结构
help():如何使用这个工具
IN[1]: import torchIN[2]: torch.cuda.is_available()Out[3]: TrueIN[4]: dir(torch)IN[5]: dir(torch.cuda)IN[6]: dir(torch.cuda.is_available)#输出的有双下划线表示不能被更改,是一个具体的函数了#选择相应的环境
conda activate torchjupyter notebookpytorch如何加载数据?Dataset:可以用于提取数据,和它的真实的label,就能得到某一个具体的数据。
Dataloader:为网络提供不同的数据形式。
如何获取每一个数据及其label。
统计总共有多少个数据。
组织形式:训练集,测试集,label信息;也可以将label直接放置在图片的名称上。
from torch.utils.data import Datasethelp(Dataset)Dataset?? #表达更清晰代码实战
from torch.utils.data import Datasetimport PIL import Imageimport os #获取所有图片的地址#import cv2# conda install opencv-python; pip install opencv-pythonclass MyData(Dataset): def __init__(self,root_dir,label_dir):#提供全局变量 self.root_dir = root_dir #创建的时候进行赋值,就可以成为全局变量 self.label_dir = label_dir self.path = os.path.join(self.root_dir,self.label_dir) self.img_path = os.listdir(self.path) #函数中的变量无法传递到其他的函数中去,self实际上就是指向类自身首地址的指针,点成员实际就是通过偏移找到相应成员 def __getitem__(self,idx): #使用idx来获取图片,首先需要获取文件夹,然后才能得到文件夹下面所有的图片 img_name = self.img_path[idx] img_item_path = os.path.join(self.root_dir,self.label_dir,img_name) img = Image.open(img_item_path) label = self.label_dir return img,label def __len__(self): return len(self.img_path) root_dir = "dataset/train"ants_label_dir = "ants"bees_label_dir = "bees"ants_dataset = MyData(root_dir,ants_label_dir)bees_dataset = MyData(root_dir,bees_label_dir)train_dataset = ants_dataset + bees_datasetinput是照片蚂蚁,label是ants
In: from PIL import ImageIn: img_path="copy relative path" #注意路径在windows里面\需要加转义字符\\# /不需要转义,\需要转义为\\In: img = Image.open(img_path)In: img.sizeIn: img.show()In: dir_path = "dataset/train/ants"In: img_path_list = os.listdir(dir_path)In: import osIn: root_dir = "relative address"In: label_dir = "ants"In: path = os.path.join(root_dir,label_dir) #拼接地址In: img_path = os.listdir(path)In: ants_dataset[0]In: img, label = ants_dataset[0]In: img.show()In: img, label = ants_dataset[1]In: img.show()In: len(ants_dataset)In: len(bees_dataset)In: len(tarin_dataset)In: img,label = train_dataset[124]
作者: 拾肆, 来源:面包板社区
链接: https://mbb.eet-china.com/blog/uid-me-4074534.html
版权声明:本文为博主原创,未经本人允许,禁止转载!
文章评论(0条评论)
登录后参与讨论