YOLOv2
YOLOv2的改进结果后的效果/pic1.png)
/pic2.png)
YOLOv1的问题
- 回归得到的box的精准不够高
- 召回率不够
一般遇到这个类问题的解决思路是把网络加深加宽,而本文通过优化网络学习在准确率不降的情况下提升精度和召回率.
Batch Normalization
CNN在训练过程中网络每层输入的分布一直在改变, 会使训练过程难度加大,但可以通过normalize每层的输入解决这个问题。新的YOLO网络在每一个卷积层后添加batch normalization,通过这一方法,mAP获得了2%的提升。batch normalization 也有助于规范化模型,可以在舍弃dropout优化后依然不会过拟合。/pic8.png)
High Resolution classifier
目前的目标检测方法中,基本上都会使用ImageNet预训练过的模型(classifier)来提取特征,如果用的是AlexNet网络,那么输入图片会被resize到不足256 * 256,导致分辨率不够高,给检测带来困难。为此,新的YOLO网络把分辨率直接提升到了448×448,这也意味之原有的网络模型必须进行某种调整以适应新的分辨率输入。
对于YOLOv2,作者首先对分类网络(自定义的darknet)进行了fine tune,分辨率改成448×448,在ImageNet数据集上训练10轮(10 epochs),训练后的网络就可以适应高分辨率的输入了。然后,作者对检测网络部分(也就是后半部分)也进行fine tune。这样通过提升输入的分辨率,mAP获得了4%的提升。
/pic9.png)
- 训练输入为224×224图片的分类网络
- 进行微调,将输入图片的大小改为448×448,仍为分类网络
- 然后微调detection部分,输入仍为448×448
Convolutional With Anchor Boxes
之前的YOLO利用全连接层的数据完成边框的预测,导致丢失较多的空间信息,定位不准。作者在这一版本中借鉴了Faster R-CNN中的anchor思想,回顾一下,anchor是RNP网络中的一个关键步骤,说的是在卷积特征图上进行滑窗操作,每一个中心可以预测9种不同大小的建议框。
首先将原网络的全连接层和最后一个pooling层去掉(网络仅采用卷积层和池化层),使得最后的卷积层可以有更高分辨率的特征;然后缩减网络,用416×416大小的输入代替原来448×448.这样做的原因是希望得到的特征图都有奇数大小的宽和高,奇数大小的宽和高会使得每个特征图在划分cell的时候就只有一个center cell(比如可以划分成7×7或9×9个cell,center cell只有一个,如果划分成8×8或10×10的,center cell有4个)。为什么希望只有一个center cell呢?因为大的object一般会占据图像的中心,所以希望用一个center cell去预测,而不是4个center cell去预测.网络最终将416×416的输入变成13×13大小的feature map输出,也就是缩小比例为32.
我们知道原来的YOLO算法将输入图像分成7×7的网格,每个网格预测两个bounding box,因此一共只有98个box,但是在YOLOv2通过引入anchor boxes,预测的box数量超过了1000(以输出feature map大小为13×13为例,每个grid cell有9个anchor box的话,一共就是13×13×9=1521个,当然由Dimension clusters可知,最终每个grid cell选择5个anchor box).顺便提一下在Faster RCNN在输入大小为1000×600时的boxes数量大概是6000,在SSD300中boxes数量是8732.显然增加box数量是为了提高oject的定位准确率.
Dimension Clusters
我们知道在Faster-RCNN中anchor box的大小和比例是按经验设定的,然后网络会在训练过程中调整anchor box的尺寸.但是如果一开始就能选择到合适尺寸的anchor box,那肯定可以帮助网络越好地预测detection.所以作者使用k-means的方式对训练集的bounding boxes做聚类,试图找到合适的anchor box.另外作者发现如果采用标准的k-means(即用欧氏距离来衡量差异),在box的尺寸较大的时候其误差也更大,而我们希望的是误差和box的尺寸没有太大关系.所以通过IOU定义了如下的距离函数,使得误差和box的大小无关:
Faster RCNN是手选9个 anchor box的宽高的大小,而YOLOv2是通过训练集训练出来适合的宽高大小。
通过训练集中各个框的宽高来聚类,得到k=5个标准的宽高.
YOLOv3也有用Dimension clusters,但要注意:yolov2里的值是相对于特征图的,值很小基本都小于13;但yolov3里的值是相对于原图来说的,相对比较大。
Direct Location prediction
作者在引入anchor box的时候遇到的第二个问题:模型不稳定,尤其是在训练刚开始的时候.作者认为这种不稳定主要来自预测box的(x,y)值,从而在YOLOv2中改为用sigmoid函数预测offset.这里的$x_a和y_a$是anchor的坐标,$w_a$和$h_a$是anchor的size,x和y是坐标的预测值,$t_x和t_y$是偏移量.
在这里作者并没有采用直接预测offset的方法,还是沿用了YOLO算法中直接预测相当于grid cell的坐标位置的方式。
前面提到网络在最后一个卷积层输出13×13大小的feature map,然后每个cell预测5个bounding box,然后每个bounding box预测5个值:$t_x,t_y,t_w,t_h和t_o$(这里的$t_o$类似YOLOv1中的confidence).看下图,$t_x$和$t_y$经过sigmoid函数处理后范围在0到1之间,这里的归一化处理也使得模型训练更加稳定;$c_x$和$c_y$表示cell和图像左上角的横纵距离;$p_w$和$p_h$表示bounding box的宽高,这样$b_x$和$b_y$就是$c_x$和$c_y$这个cell附近的anchor来预测$t_x$和$t_y$得到的结果.
上面公式可看下图来理解,首先$c_x$和$c_y$,表示grid cell与图像左上角的横纵坐标距离,黑色虚线框是bounding box,蓝色矩形框就是预测结果./pic3.png)
Fine-Grained features
这里主要是添加了一个层:passthrough layer.这个层的作用就是将前面一层的26×26的feature map和本层的13×13的feature map进行连接,有点像ResNet.这样做的原因在于虽然13×13的feature map对于预测大的object已经足够了,但是对于预测小的object就不一定有效.也容易理解,越小的object,经过层层卷积和pooling,可能到最后都不见了,所以通过合并前一层的size大一点的feature map可以有效检测小的object.
Multi-Scale Training
采用不同尺寸的图片训练,提高鲁棒性
为了让YOLOv2模型更加robust,作者引入了Multi-Scale Training,简单讲就是在训练时输入图像的size是动态变化的,注意这一步是在检测数据集上fine tune时候采用的,不要跟前面在Imagenet数据集上的两步预训练分类模型混淆,本文细节确实很多.具体,在训练网络时,每训练10个epoch,网络就会随机选择另一种size的输入.那么输入图像的size的变化范围要怎么定?前面我们知道本文网络本来的输入是416×416,最后会输出13×13的feature map,也就是说downsample的factor是32,因此作者采用32的倍数作为输入的size,具体来说文中作者采用{320,352,…,608}的输入尺寸.(例如输入图片的大小为608×608,对应的特征图大小为19×19,downsample factor是32)
这种网络训练方式使得相同网络可以对不同分辨率的图像做detection.虽然在输入size较大时,训练速度较慢,但同时在输入size较小时,训练速度较快,而multi-scale training又可以提高准确率,因此算是准确率和速度都取得一个不错的平衡.
基本思路
下列是模型的大致结构:/pic4.png)
其主要由两个部分构成:
- 神经网络:将图片计算为一个13×13×125的向量,该向量包含了预测的物品位置和类别信息
- 检测器:将神经网络输出的向量进行”解码”操作,输出物品的分类和位置.
神经网络部分
YOLOv2的神经网络部分使用了一个带跳层的神经网络(在Fine-Grained features里有介绍),具体如下所示:/pic5.png)
检测器部分
YOLOv2使用Anchor Box的方法,神经网络输出的向量尺寸是13×13×125,其中13×13是将图片划分为13行和13列共169个cell,每个cell有125数据.对于每个cell的125个数据,分解为125=5×(5+20),即每个cell包括5个anchor box,每个anchor cell包括25个数据,分别为物品存在置信度、物品中心位置(x,y),物品尺寸(w,h)和类别信息(20个)。如下图所示:/pic6.png)
对于每个cell包括5个anchor box信息,每个anchor box包括25个数据,分别:
- 为是否有物品(1个)
- 物品位置(4个)
- 物品类别(20个)
 其中是否有物品的标记$conf_{ijk}$比较容易理解,表示位于i,j cell的第k个anchor box中有物品的置信度.20个物品种类向量也比较好理解,哪一个数据最大即物品为对应的类别.
对于物品位置的四个数据分别为$x_{ijk},y_{ijk},w_{ijk},h_{ijk}$,与物品位置中心点和尺寸的关系为:
其中,$b_x,b_y$为物品中心点的实际坐标,$b_w,b_h$为物品的尺寸(长宽).$c_x,c_y$为该cell(x行y列)距离图片左上角的像素数,f的含义推测为将范围0~1的输入值缩放到0~cell长度.$p_w$和$p_h$为该anchor box的预设尺寸.如下图所示:/pic7.png)
每个cell包括5个anchor box,这5个anchor box有不同的预设尺寸,该预设尺寸可以手动指定也可以在训练集上训练获得。在YOLOv2中,预设尺寸是通过在测试集上进行聚类获得的.
模型训练
神经网络部分基于Darknet-19,该模型的训练部分分为两个部分:预训练和训练部分
- 预训练:预训练是在ImageNet上按分类的方式进行预训练160轮,使用SGD优化方法,初始学习率0.1,每次下降4倍,到0.0005时终止.除了训练224×224尺寸的图像外,还是用448×448尺寸的图片
- 训练:去除DarkNet的最后一个卷积层,并将网络结构修改为YOLOv2的网络,在VOC数据集上进行训练,训练使用的代价函数是MSE代价函数.
- YOLO2损失函数:有4类loss,它们的weight不同,分别是object、noobject、class、coord.总体loss是这4个部分的平方和.
代码实现:
首先要明确:
YOLO是将图片输入给网络以后直接回归出框的坐标信息以及类别信息.所以对于网络我们只需要明确它的输入大小是什么输出大小是什么。
前面提到网络是一个Darknet,训练分为两步:1.预训练:在ImageNet上按分类方式训练;(先对224×224的图片训练分类,再对448×448的图片训练分类).2.微调:修改原来的网络以适应现在的任务.前面提到过,为了使图片能分割成奇数×奇数的块,最后训练时图片的输入大小为416,缩小32倍为13.并且这个输入大小还会变,因为需要Multi-Scale Training.这里只讨论分块为13×13的形式。而输出方面,由于YOLOv2借鉴了Faster RCNN,使用了anchor box的方法。为每个grid cell预测5个anchor box,而每个anchor box给出2个位置信息(中心),2个框大小信息(w和h),1个置信度信息(是否有物体),20个类别信息.所以输出形式为13×13×(5×(2+2+1+20))=13×13×125.输出的特征图的形式是宽高为13,channel为125.在pytorch中为(1,125,13,13)1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43class Darknet(nn.Module):
    cfg1 = [32, 'M', 64, 'M', 128, (64,1), 128, 'M', 256, (128,1), 256, 'M', 512, (256,1), 512, (256,1), 512]  # conv1 - conv13
    cfg2 = ['M', 1024, (512,1), 1024, (512,1), 1024]  # conv14 - conv18
    def __init__(self):
        super(Darknet,self).__init__()
        self.layer1=self.make_layers(self.cfg1,in_planes=3)
        self.layer2=self.make_layers(self.cfg2,in_planes=512)
        ##Add new layers
        self.conv19=nn.Conv2d(1024,1024,kernel_size=3,stride=1,padding=1)
        self.bn19=nn.BatchNorm2d(1024)
        self.conv20=nn.Conv2d(1024,1024,kernel_size=3,stride=1,padding=1)
        self.bn20=nn.BatchNorm2d(1024)
        self.conv21=nn.Conv2d(1024,1024,kernel_size=3,stride=1,padding=1)
        self.bn21=nn.BatchNorm2d(1024)
        self.conv22=nn.Conv2d(1024,5*(5+20),kernel_size=1,stride=1,padding=0)
    def make_layers(self,cfg,in_planes):
        layers=[]
        for x in cfg:
            if x=='M':
                layers+=[nn.MaxPool2d(kernel_size=2,stride=2,ceil_mode=True)]
            else:
                out_planes=x[0] if isinstance(x,tuple) else x
                ksize=x[1] if isinstance(x,tuple) else 3
                layers+=[nn.Conv2d(in_planes,out_planes,kernel_size=ksize,padding=(ksize-1)//2),
                         nn.BatchNorm2d(out_planes),
                         nn.LeakyReLU(0.1,True)]
                in_planes=out_planes
        return nn.Sequential(*layers)
    def forward(self,x):
        out=self.layer1(x)
        out=self.layer2(out)
        out=F.leaky_relu(self.bn19(self.conv19(out)),0.1)
        out=F.leaky_relu(self.bn20(self.conv20(out)),0.1)
        out=F.leaky_relu(self.bn21(self.conv21(out)),0.1)
        out=self.conv22(out)
        return out
可以验证一下:1
2
3net=Darknet()
y=net(Variable(torch.randn(1,3,416,416)))
print(y.size())
接下来将主要介绍如何处理:数据里的信息.以及如何处理网络最后的YOLOLoss.
box的iou计算在目标检测中是常见的:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24def box_iou(box1,box2):
    '''
    box1,box2 are as:[xmin,ymin,xmax,ymax]
    '''
    N=box1.size(0)
    M=box2.size(0)
    ##求框左上角最大值的点(即相交框的左上角)
    lt=torch.max(box1[:,:2].unsqueeze(1).expand(N,M,2),#[N,2]->[N,1,2]->[N,M,2]
                 box2[:,:2].unsqueeze(0).expand(N,M,2) )#[M,2]->[1,M,2]->[N,M,2]
    ##求框右下角最小值的点(即相交框的右下角)
    rb=torch.min(box1[:,2:].unsqueeze(1).expand(N,M,2),
                 box2[:,2:].unsqueeze(0).expand(N,M,2))
    wh=(rb-lt).clamp(min=0)##clamp是阶段,两点相减不能有负数
    inter=wh[:,:,0]*wh[:,:,1]##计算重合框的面积 [N,M]
    area1=(box1[:,2]-box1[:,0])*(box1[:,3]-box1[:,1])##计算box1中各个框的面积  [N]
    area2=(box2[:,2]-box[:,0])*(box2[:,3]-box2[:,1])##计算box2中各个框的面积 [M]
    area1=area1.unsqueeze(1).expand_as(inter)## 扩展成[N×M]
    area2=area2.unsqueeze(0).expand_as(inter)##扩展成[N×M]
    iou=inter/(area1+area2-inter)##得到[N×M]box1中各个框与box2中各个框的iou
    return iou
还有一个重要的函数是:box_nms ,非极大抑制(Non-Maximum Suppression)
这个我们在RCNN中有看到过,目的是去除冗余的检测框,保留最好的一个.
其原理是:对于Bounding Box的一个列表B及其对应的置信度S,采用下面方式:选择具有最大score的检测框M,将其从B集合中移除并加入到最终的检测结果D中。通常将B中剩余检测框中与M的IOU大于阈值threshold的框从B中移除.重复这个过程,直到B为空。
| 1 | def box_nms(bboxes,scores,threshold=0.5,mode='union'): | 
接下来就是处理我们的数据集,之前提到我们输入的图片,这显然不需要我们处理,顶多改变一下图片大小或者做一些数据增强。关键的是我们要处理一下targets,这里我们称这步为编码location和class_label
| 1 | ''' | 
以上就是我们要对训练集的目标框做的处理,得到我们需要的targets
下面是我们处理的dataset1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117class ListDataset(data.Dataset):
    input_sizes=[320+32*i for i in range(10)]
    def __init__(self,root,list_file,train,transform):
        '''
        root:(directory) to images
        list_file:(str)path to index file
        transform:([transforms])image transforms
        '''
        self.root=root
        self.train=train
        self.transform=transform
        self.fnames=[]
        self.boxes=[]
        self.labels=[]
        self.data_encoder=DataEncoder()
        ##从file中得到各张图的目标数据
        with open(list_file) as f:
            line=f.readlines()
            self.num_samples=len(lines)
        for line in lines:
            splited=line.strip().split()
            self.fnames.append(splited[0])
            num_boxes=(len(splited)-3)//5
            box=[]
            label=[]
            for i in range(num_boxes):
                xmin=splited[3+5*i]
                ymin=splited[4+5*i]
                xmax=splited[5+5*i]
                ymax=splited[6+5*i]
                c=splited[7+5*i]
                box.append([float(xmin),float(ymin),float(xmax),float(ymax)])##左上角,右下角的点
                label.append(int(c))##类别
            self.boxes.append(torch.Tensor(box))
            self.labels.append(torch.LongTensor(label))
        def __getitem__(self,idx):
            '''
            load a image,and encode its bbox location and class labels
            '''
            fname=self.fnames[idx]
            img=Image.open(os.path.join(self.root,fname))
            boxes=self.boxes[idx].clone()
            labels=self.labels[idx]
            if self.train:
                img,boxes=self.random_flip(img,boxes)
                img,boxes,labels=self.random_crop(img,boxes,labels)
            w,h=img.size
            boxes/=torch.Tensor([w,h,w,h]).expand_as(boxes)##boxes里的值都是[0,1],这里必须除的原图的size而不能是input_size
            input_size=416
            img=img.resize((input_size,input_size))
            img=self.transform(img)
            #encode data
            loc_targets,cls_targets,box_targets=self.data_encoder.encode(boxes,labels,input_size)
            return img,loc_targets,cls_targets,box_targets
        def random_flip(self,img,boxes):
            if random.random()<0.5:
                img=img.transpose(Image.FLIP_LEFT_RIGHT)
                w=img.width
                xmin=w-boxes[:,2]
                xmax=w-boxes[:,0]
                boxes[:,0]=xmin
                boxes[:,2]=xmax
            return img,boxes
        def random_crop(self,img,boxes,labels):
            while True:
                min_iou=random.choice([None,0.1,0.3,0.5,0.7,0.9])
                if min_iou is None:
                    return img,boxes,labels
                for _ in range(100):
                    w=random.randrange(int(0.1*imw),imw)
                    h=random.randrange(int(0.1*imh),imh)
                    if h>2*w or w>2*h:
                        continue
                    x=random.randrange(imw-w)
                    y=random.randrange(imh-h)
                    roi=torch.Tensor([[x,y,x+w,y+h]])
                    center=(boxes[:,:2]+boxes[:,2:])/2
                    roi2=roi.appemd(len(center),4)
                    mask=(center>roi2[:,:2])&(center<roi2[:,2:])
                    mask=mask[:,0]&mask[:,1]
                    if not mask.any():
                        continue
                    selected_boxes=boxes.index_select(0,mask.nonzero().squeeze(1))
                    ious=box_iou(selected_boxes,roi)
                    if ious.min()<min_iou:
                        continue
                    img=img.crop((x,y,x+w,y+h))
                    selected_boxes[:,0].add_(-x).clamp_(min=0,max=w)
                    selected_boxes[:,1].add_(-y).clamp_(min=0,max=h)
                    selected_boxes[:,2].add_(-x).clamp_(min=0,max=w)
                    selected_boxes[:,3].add_(-y).clamp_(min=0,max=h)
                    return img,selected_boxes,labels[mask]
                def __len__(self):
                    return self.num_samples
最后我们需要重新设计一个YOLOLoss,使其能更好的从训练集中学习到参数1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59class YOLOLoss(nn.Module):
    def __init__(self):
        super(YOLOLoss,self).__init__()
    def decode_loc(self,loc_preds):
        anchors=[(1.3221,1.73145),(3.19275,4.00944),(5.05587,8.09892),(9.47112,4.84053),(11.2364,10.0071)]
        N,_,_,fmsize,_=loc_preds.size()
        loc_xy=loc_preds[:,:,:2,:,:]
        grid_xy=meshgrid(fmsize,swap_dims=True).view(fmsize,fmsize,2).permute(2,0,1)
        grid_xy=Variable(grid_xy.cuda())
        box_xy=loc_xy.sigmoid()+grid_xy.expand_as(loc_xy)
        loc_wh=loc_preds[:,:,2:4,:,:]
        anchor_wh=torch.Tensor(anchors).view(1,5,2,1,1).expand_as(loc_wh)
        anchor_wh=Variable(anchor_wh.cuda())
        box_wh=anchor_wh*loc_wh.exp()
        box_preds=torch.cat([box_xy-box_wh/2,box_xy+box_wh/2],2)
        return box_preds
    def forward(self,preds,loc_targets,cls_targets,box_targets):
        batch_size,_,fmsize,_=preds.size()
        preds=preds.view(batch_size,5,4+1+20,fmsize,fmsize)##[batch_siz,5,25,fmsize,fmsize]
        xy=preds[:,:,:2,:,:].sigmoid()##中心点
        wh=preds[:,:,2:4,:,:].exp()##长宽
        loc_preds=torch.cat([xy,wh],2)##(batchsize,5,4,fmsize,fmsize)
        pos=cls_targets.max(2)[0].squeeze()>0
        num_pos=pos.data.long().sum()
        mask=pos.unsqueeze(2).expand_as(loc_preds)
        loc_loss=F.smooth_l1_loss(loc_preds[mask],loc_targets[mask],size_average=False)
        #iou_loss
        iou_preds=preds[:,:,4,:,:].sigmoid()#置信度
        iou_targets=Variable(torch.zeros(iou_preds.size()).cuda())
        box_preds=self.decode_loc(preds[:,:,:4,:,:])
        box_preds=box_preds.permute(0,1,3,4,2).contiguous().view(batch_size,-1,4)
        for i in range(batch_size):
            box_pred=box_preds[i]
            box_target=box_targets[i]
            iou_target=box_iou(box_pred,box_target)
            iou_targets[i]=iou_target.max(1)[0].view(5,fmsize,fmsize)
        mask=Variable(torch.ones(iou_preds.size()).cuda())*0.1
        mask[pos]=1
        iou_loss=F.smooth_l1_loss(iou_preds*mask,iou_targets*mask,size_average=False)
        ##cls_loss
        cls_preds=preds[:,:,5:,:,:]
        cls_preds=cls_preds.permute(0,1,3,4,2).contiguous().view(-1,20)
        cls_preds=F.softmax(cls_preds)
        cls_preds=cls_preds.view(batch_size,5,fmsize,fmsize,20).permute(0,1,4,2,3)
        pos=cls_targets>0
        cls_loss=F.smooth_l1_loss(cls_preds[pos],cls_targets[pos],size_average=False)
        return (loc_loss+iou_loss+cls_loss)/num_pos
最后一步就是将,上述这几步组装在一起进行训练:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90import torch.optim as optim
best_loss=float('inf')#best test loss
start_epoch=0
lr=0.0001
root='/search/data/user/liukuang/data/VOC2012_trainval_test_images'
list_file='/search/data/user/liukuang/data/VOC2012_trainval_test_images'
root_test='/search/data/user/liukuang/data/VOC2012_trainval_test_images'
list_file_test='./voc_data/voc12_test.txt'
def collate_fn(batch):
    return torch.stack([x[0] for x in batch]),\
            torch.stack([x[1] for x in batch]),\
            torch.stack([x[2] for x in batch]),\
            [x[3] for x in batch]
transform=transforms.Compose([transforms.ToTensor()])
trainset=ListDataset(root=root,list_file=list_file,train=True,transform=transform)
trainloader=torch.utils.data.DataLoader(trainset,batch_size=32,shuffle=True,num_workers=8,collate_fn=collate_fn)
testset=ListDataset(root=root_test,list_file=list_file_test,train=False,transform=transform)
testloader=torch.utils.data.DataLoader(testset,batch_size=32,shuffle=False,num_workers=8,collate_fn=collate_fn)
net=Darknet()
net=torch.nn.DataParallel(net,device_ids=range(torch.cuda.device_count()))
net.cuda()
cudnn.benchmark=True
criterion=YOLOLoss()
optimizer=optim.SGD(net.parameters(),lr=lr,momentum=0.9,weight_decay=1e-4)
def train(epoch):
    print('\n Epoch:%d'%epoch)
    net.train()
    trian_loss=0
    for batch_idx,(images,loc_targets,cls_targets,box_targets) in enumerate(trainloader):
        images=Variable(images.cuda())
        loc_targets=Variable(loc_targets.cuda())
        cls_targets=Variable(cls_targets.cuda())
        box_targets=[Variable(x.cuda()) for x in box_targets]
        optimizer.zero_grad()
        outputs=net(images)
        loss=criterion(outputs,loc_targets,cls_targets,box_targets)
        loss.backward()
        optimizer.step()
        train_loss+=loss.data[0]
        print('%.3f %.3f'%(loss.data[0],train_loss/(batch_idx+1)))
def test(epoch):
    print('\n Test')
    net.eval()
    test_loss=0
    for batch_idx,(images,loc_targets,cls_targets,box_targets) in enumerate(testloader):
        images=Variable(images.cuda())
        loc_targets=Variable(loc_targets.cuda())
        cls_targets=Variable(cls_targets.cuda())
        box_targets=[Variable(x.cuda()) for x in box_targets]
        outputs=net(images)
        loss=criterion(outputs,loc_targets,cls_targets,box_targets)
        test_loss+=loss.data[0]
        print('%.3f %.3f'%(loss.data[0],test_loss/(batch_idx+1)))
    global best_loss
    test_loss/=len(testloader)
    if(test_loss<best_loss):
        print("Saving...")
        state={
            'net':net.module.state_dict(),
            'loss':test_loss,
            'epoch':epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state,'./checkpoint/ckpt/pth')
        best_loss=test_loss
for epoch in range(start_epoch,start_epoch+200):
    train(epoch)
    test(epoch)
