
[Tensorflow] Batch Normalization Implementation


Advantages of BN:

(1) A larger learning rate can be used (without BN, a learning rate that is too large tends to make gradients explode/vanish, or get the model stuck in poor local minima).
(2) Dropout is no longer necessary.
(3) Training is less sensitive to weight initialization.

However, getting the speedup is not just a matter of inserting BN layers; the paper also makes the following changes:
(1) Use a larger initial learning rate and decay it faster (e.g. increase the learning rate 5x from 0.0015 to 0.0075, and decay it 6x faster).
(2) Remove dropout.
(3) Reduce the L2 weight decay (e.g. divide it by 5).
(4) Remove Local Response Normalization (LRN).
(5, 6) Other changes; see the paper.
ResNet also uses BN. For its CIFAR-10 networks the hyperparameters are:
For the 20/32/44/56-layer networks:
  learning rate = 0.1, divided by 10 at 32k and 48k iterations;
  L2 weight decay = 1e-4.
For the 110-layer network: a learning rate of 0.01 is used to warm up training until the training error drops below 80%, after which training continues at 0.1.
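
As a sketch, such a step schedule can be written with tf.train.piecewise_constant; the boundaries and values below simply mirror the CIFAR-10 numbers quoted above:

import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
#start at 0.1, divide by 10 at 32k and 48k iterations
learning_rate = tf.train.piecewise_constant(global_step,
                                            boundaries=[32000, 48000],
                                            values=[0.1, 0.01, 0.001])
optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)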

1. tf.nn.batch_normalization

Tensorflow provides an API for Batch Normalization. This API is very flexible, and the price of that flexibility is that we have to supply all of the pieces ourselves
(for example, we even have to compute the mean and variance of the input Tensor on our own).

tf.nn.batch_normalization(
  x,                #Tensor to normalize
  mean,             #Tensor, usually the mean of x, float32
  variance,         #Tensor, usually the variance of x, float32
  offset,           #Tensor, the beta value (the shift part of BN), usually initialized to 0
  scale,            #Tensor, the gamma value (the scale part of BN), usually initialized to 1
  variance_epsilon, #small float to avoid division by zero
  name=None
)
"""
返回值(Tensor): 
  y= (x-mean)/sqrt(variance^2+variance_epsilon)*scale+offset。
  
但是mean和variance需要自己提前计算,
而tensorflow又提供了另一个API来计算mean和variance。(当然我们也可以自己瞎搞一个)
"""
This API follows the paper's formulation exactly and is even more flexible (for example, mean and variance can be set to values other than the mean and variance of x, and the same goes for beta and gamma).
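
A minimal sketch to check that formula: computing (x - mean) / sqrt(variance + eps) by hand in NumPy should match what tf.nn.batch_normalization returns (scale and offset left as None, i.e. gamma=1, beta=0):

import numpy as np
import tensorflow as tf

x_np = np.array([[1., 5.], [10., 100.]], dtype=np.float32)
eps = 1e-10
y_manual = (x_np - x_np.mean()) / np.sqrt(x_np.var() + eps)   #gamma=1, beta=0

x = tf.constant(x_np)
y_api = tf.nn.batch_normalization(x, tf.constant(x_np.mean()), tf.constant(x_np.var()),
                                  None, None, eps)            #offset/scale None => 0/1
with tf.Session() as sess:
    print(np.allclose(y_manual, sess.run(y_api)))             #True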

2. Computing a Tensor's mean and variance: tf.nn.moments

  Since tf.nn.batch_normalization requires mean and variance to be computed by hand, this API is used to do it.
tf.nn.moments(
  x,              #Tensor whose mean and variance we want
  axes,           #dimensions to reduce over; in this post all of them, i.e. [d for d in range(len(x.get_shape()))]
  shift=None, 
  name=None, 
  keep_dims=False
)
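
Note that the axes argument decides what statistics you get: the example below reduces over every axis and gets scalar mean/variance, while for a convolutional NHWC feature map BN usually reduces over batch, height and width only, giving per-channel statistics. A small sketch:

import tensorflow as tf

x = tf.random_normal([32, 28, 28, 16])                 #NHWC feature map
mean_all, var_all = tf.nn.moments(x, axes=[0,1,2,3])   #scalars, as used in this post
mean_c,   var_c   = tf.nn.moments(x, axes=[0,1,2])     #per-channel, shape [16]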


3. Example

import tensorflow as tf
sess=tf.Session()
x=tf.constant([[1,5],[10,100]],dtype=tf.float32)
#axes to reduce over (all dimensions of x)
axes=[d for d in range(len(x.get_shape()))]
#beta and gamma parameters
beta= tf.get_variable("beta",shape=[],initializer=tf.constant_initializer(0.0))
gamma=tf.get_variable("gamma",shape=[],initializer=tf.constant_initializer(1.0))
sess.run(tf.global_variables_initializer())
#compute mean and variance, then apply BN
x_mean,x_variance=tf.nn.moments(x,axes)
y=tf.nn.batch_normalization(x,x_mean,x_variance,beta,gamma,1e-10,"bn")
#check the statistics of the result
y_mean,y_variance=tf.nn.moments(y,axes)
x_val,xm_val,xv_val,y_val,ym_val,yv_val=sess.run([x,x_mean,x_variance,y,y_mean,y_variance])
print("*********执行BN前的Variable x:************")
print("x=%s\n x mean=%s\n x variance=%s" %(x_val,xm_val,xv_val))
print("*********执行BN后的Variable y:************")
print("y=%s \n y mean=%s\n y variance=%s" %(y_val,ym_val,yv_val))

The output is:
*********Variable x before BN:************
x=[[   1.    5.]
 [  10.  100.]]
 x mean=29.0
 x variance=1690.5
*********Variable y after BN:************
y=[[-0.68100518 -0.58371872]
 [-0.46211064  1.72683454]] 
 y mean=0.0
 y variance=1.0

As expected, after BN the output y has mean 0 and variance 1 (with beta = 0 and gamma = 1).
If we change the initial values of beta and gamma, the mean of y becomes beta and its variance becomes gamma^2.
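
A quick sanity-check sketch (reusing x, x_mean, x_variance, axes and sess from the example above): with beta=2 and gamma=3, y should end up with mean ≈ 2 and variance ≈ 9.

beta2  = tf.constant(2.0)
gamma2 = tf.constant(3.0)
y2 = tf.nn.batch_normalization(x, x_mean, x_variance, beta2, gamma2, 1e-10)
print(sess.run(tf.nn.moments(y2, axes)))   #mean ~ 2.0, variance ~ 9.0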

4. Where to place the BN layer

   In a BN network, a convolutional or fully connected layer applies three operations to its input x: the BN operation, the weight (conv/FC) operation, and the ReLU. In what order should these be arranged?
   The original paper says that any layer that previously received x as input now receives BN(x), but what about the sub-steps inside a single convolutional layer?
   For a 2×conv16 => 2×conv32 => 2×conv64 => fc-10 network on MNIST I tried three orderings (made concrete in the sketch below):
  (1) x -> bn -> weight -> relu
  (2) x -> bn -> relu -> weight
  (3) x -> weight -> bn -> relu
   All of them worked well, probably because this dataset is too simple; it needs further testing.
   In the ~1000-layer ResNet, however, ordering (2) works better than ordering (3) (when shortcuts are used).
Paper: https://arxiv.org/pdf/1603.05027.pdf
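
Written in terms of the conv/bn/ACTIVATE helpers from the VGG class at the end of this post (c stands for the number of output channels), the three orderings look like this; it is only a sketch of the expression structure:

#(1) x -> bn -> weight -> relu
y = self.ACTIVATE(self.conv(self.bn(x), channels=c, name="conv"))
#(2) x -> bn -> relu -> weight
y = self.conv(self.ACTIVATE(self.bn(x)), channels=c, name="conv")
#(3) x -> weight -> bn -> relu
y = self.ACTIVATE(self.bn(self.conv(x, channels=c, name="conv")))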

5. Effect of BN on MNIST

MNIST is so easy that whether a normal CNN has BN layers or not makes little visible difference.
So we make training harder for the model: replace ReLU with Sigmoid.
(With Sigmoid, training becomes painfully slow, easily on the order of a hundred times slower; at first I thought the network was broken. ReLU really is powerful!)
All other settings are the same: a 784*100*100*100*10 fully connected model, learning rate 1e-4, momentum=0.9, L2 weight decay=1e-4, batch size 50, trained for 10 epochs.

Training results without BN:

[step100] accuracy=0.1 loss=116.691
[step200] accuracy=0.16 loss=114.826
[step300] accuracy=0.1 loss=115.051
[step400] accuracy=0.1 loss=117.023
[step500] accuracy=0.08 loss=115.734
[step600] accuracy=0.14 loss=114.13
[step700] accuracy=0.1 loss=115.985
[step800] accuracy=0.1 loss=115.7
[step900] accuracy=0.02 loss=117.614
[step1000] accuracy=0.1 loss=115.558
[*]Test Result=0.0892000000738 at epoch0
[step100] accuracy=0.08 loss=114.817
[step200] accuracy=0.22 loss=113.812
[step300] accuracy=0.1 loss=115.722
[step400] accuracy=0.04 loss=116.21
[step500] accuracy=0.14 loss=115.215
[step600] accuracy=0.08 loss=115.071
[step700] accuracy=0.14 loss=115.076
[step800] accuracy=0.06 loss=116.63
[step900] accuracy=0.12 loss=114.81
[step1000] accuracy=0.08 loss=115.669
[*]Test Result=0.100900000408 at epoch1
[step100] accuracy=0.1 loss=115.425
[step200] accuracy=0.1 loss=115.394
[step300] accuracy=0.08 loss=115.214
[step400] accuracy=0.04 loss=114.856
[step500] accuracy=0.08 loss=117.108
[step600] accuracy=0.14 loss=113.223
[step700] accuracy=0.08 loss=115.142
[step800] accuracy=0.16 loss=114.448
[step900] accuracy=0.1 loss=114.995
[step1000] accuracy=0.18 loss=115.651
[*]Test Result=0.113499999568 at epoch2
[step100] accuracy=0.12 loss=114.254
[step200] accuracy=0.08 loss=116.074
[step300] accuracy=0.2 loss=113.781
[step400] accuracy=0.08 loss=115.302
[step500] accuracy=0.06 loss=115.785
[step600] accuracy=0.08 loss=116.462
[step700] accuracy=0.08 loss=114.897
[step800] accuracy=0.14 loss=116.592
[step900] accuracy=0.1 loss=116.425
[step1000] accuracy=0.06 loss=114.058
[*]Test Result=0.103200000077 at epoch3
[step100] accuracy=0.26 loss=113.873
[step200] accuracy=0.08 loss=115.774
[step300] accuracy=0.14 loss=114.722
[step400] accuracy=0.1 loss=114.43
[step500] accuracy=0.12 loss=114.766
[step600] accuracy=0.08 loss=116.453
[step700] accuracy=0.02 loss=116.828
[step800] accuracy=0.06 loss=115.831
[step900] accuracy=0.14 loss=114.576
[step1000] accuracy=0.04 loss=114.588
[*]Test Result=0.113499999568 at epoch4
[step100] accuracy=0.24 loss=114.013
[step200] accuracy=0.1 loss=115.269
[step300] accuracy=0.08 loss=115.71
[step400] accuracy=0.18 loss=113.4
[step500] accuracy=0.14 loss=115.153
[step600] accuracy=0.08 loss=114.52
[step700] accuracy=0.12 loss=114.871
[step800] accuracy=0.22 loss=115.017
[step900] accuracy=0.12 loss=113.872
[step1000] accuracy=0.12 loss=115.084
[*]Test Result=0.171800000742 at epoch5
[step100] accuracy=0.12 loss=116.787
[step200] accuracy=0.1 loss=116.283
[step300] accuracy=0.04 loss=115.422
[step400] accuracy=0.14 loss=114.826
[step500] accuracy=0.18 loss=114.08
[step600] accuracy=0.14 loss=114.935
[step700] accuracy=0.18 loss=114.367
[step800] accuracy=0.02 loss=115.996
[step900] accuracy=0.08 loss=114.403
[step1000] accuracy=0.24 loss=113.339
[*]Test Result=0.113499999568 at epoch6
[step100] accuracy=0.18 loss=114.502
[step200] accuracy=0.12 loss=114.226
[step300] accuracy=0.14 loss=114.238
[step400] accuracy=0.28 loss=113.135
[step500] accuracy=0.04 loss=115.067
[step600] accuracy=0.16 loss=113.927
[step700] accuracy=0.1 loss=113.124
[step800] accuracy=0.06 loss=114.841
[step900] accuracy=0.16 loss=113.212
[step1000] accuracy=0.26 loss=112.934
[*]Test Result=0.199200000018 at epoch7
[step100] accuracy=0.16 loss=114.148
[step200] accuracy=0.12 loss=113.84
[step300] accuracy=0.14 loss=112.673
[step400] accuracy=0.2 loss=112.878
[step500] accuracy=0.2 loss=114.386
[step600] accuracy=0.12 loss=112.982
[step700] accuracy=0.38 loss=111.301
[step800] accuracy=0.3 loss=112.395
[step900] accuracy=0.52 loss=110.003
[step1000] accuracy=0.12 loss=111.22
[*]Test Result=0.122199999765 at epoch8
[step100] accuracy=0.08 loss=112.523
[step200] accuracy=0.42 loss=108.418
[step300] accuracy=0.4 loss=105.239
[step400] accuracy=0.5 loss=98.153
[step500] accuracy=0.22 loss=103.485
[step600] accuracy=0.2 loss=104.636
[step700] accuracy=0.48 loss=95.7585
[step800] accuracy=0.24 loss=94.8633
[step900] accuracy=0.38 loss=93.5662
[step1000] accuracy=0.36 loss=89.0528
[*]Test Result=0.351300003231 at epoch9

After 10 epochs the test accuracy only reaches about 35%.

Training results with BN layers added:

[step100] accuracy=0.1 loss=116.102
[step200] accuracy=0.22 loss=112.854
[step300] accuracy=0.14 loss=115.377
[step400] accuracy=0.1 loss=115.649
[step500] accuracy=0.1 loss=115.625
[step600] accuracy=0.24 loss=114.879
[step700] accuracy=0.1 loss=115.61
[step800] accuracy=0.12 loss=114.699
[step900] accuracy=0.14 loss=115.097
[step1000] accuracy=0.1 loss=114.932
[*]Test Result=0.0974000002816 at epoch0
[step100] accuracy=0.1 loss=116.12
[step200] accuracy=0.06 loss=116.164
[step300] accuracy=0.1 loss=115.818
[step400] accuracy=0.12 loss=115.697
[step500] accuracy=0.18 loss=115.264
[step600] accuracy=0.2 loss=114.414
[step700] accuracy=0.04 loss=115.895
[step800] accuracy=0.12 loss=114.564
[step900] accuracy=0.06 loss=115.524
[step1000] accuracy=0.22 loss=114.622
[*]Test Result=0.161500000656 at epoch1
[step100] accuracy=0.2 loss=115.315
[step200] accuracy=0.14 loss=114.43
[step300] accuracy=0.1 loss=115.918
[step400] accuracy=0.16 loss=114.786
[step500] accuracy=0.26 loss=112.941
[step600] accuracy=0.3 loss=113.985
[step700] accuracy=0.3 loss=112.463
[step800] accuracy=0.14 loss=113.471
[step900] accuracy=0.14 loss=112.914
[step1000] accuracy=0.14 loss=112.23
[*]Test Result=0.24730000034 at epoch2
[step100] accuracy=0.2 loss=111.719
[step200] accuracy=0.32 loss=108.348
[step300] accuracy=0.24 loss=106.837
[step400] accuracy=0.36 loss=102.211
[step500] accuracy=0.32 loss=99.1392
[step600] accuracy=0.42 loss=94.0066
[step700] accuracy=0.5 loss=82.9231
[step800] accuracy=0.5 loss=78.0428
[step900] accuracy=0.56 loss=75.0709
[step1000] accuracy=0.56 loss=72.2615
[*]Test Result=0.569599996507 at epoch3
[step100] accuracy=0.54 loss=72.2187
[step200] accuracy=0.62 loss=62.6503
[step300] accuracy=0.7 loss=51.1989
[step400] accuracy=0.7 loss=50.0574
[step500] accuracy=0.62 loss=48.4715
[step600] accuracy=0.58 loss=56.4319
[step700] accuracy=0.76 loss=48.7727
[step800] accuracy=0.76 loss=39.0827
[step900] accuracy=0.66 loss=44.0735
[step1000] accuracy=0.74 loss=40.6393
[*]Test Result=0.731999999881 at epoch4
[step100] accuracy=0.82 loss=39.1621
[step200] accuracy=0.8 loss=33.0594
[step300] accuracy=0.68 loss=41.5027
[step400] accuracy=0.72 loss=49.6565
[step500] accuracy=0.8 loss=32.1081
[step600] accuracy=0.8 loss=42.5631
[step700] accuracy=0.84 loss=31.7484
[step800] accuracy=0.8 loss=34.406
[step900] accuracy=0.7 loss=36.0701
[step1000] accuracy=0.76 loss=39.4207
[*]Test Result=0.798400003314 at epoch5
[step100] accuracy=0.66 loss=38.2423
[step200] accuracy=0.88 loss=23.5632
[step300] accuracy=0.8 loss=37.7658
[step400] accuracy=0.8 loss=41.1382
[step500] accuracy=0.84 loss=31.7916
[step600] accuracy=0.86 loss=24.6395
[step700] accuracy=0.8 loss=29.7371
[step800] accuracy=0.84 loss=33.4366
[step900] accuracy=0.84 loss=25.56
[step1000] accuracy=0.92 loss=23.0958
[*]Test Result=0.841499999762 at epoch6
[step100] accuracy=0.9 loss=17.4944
[step200] accuracy=0.74 loss=35.0277
[step300] accuracy=0.9 loss=30.2663
[step400] accuracy=0.78 loss=34.679
[step500] accuracy=0.82 loss=25.4055
[step600] accuracy=0.86 loss=19.0345
[step700] accuracy=0.98 loss=14.34
[step800] accuracy=0.86 loss=27.425
[step900] accuracy=0.78 loss=35.237
[step1000] accuracy=0.88 loss=23.2125
[*]Test Result=0.86880000174 at epoch7
[step100] accuracy=0.88 loss=23.3765
[step200] accuracy=0.82 loss=33.0606
[step300] accuracy=0.76 loss=44.3354
[step400] accuracy=0.9 loss=17.5737
[step500] accuracy=0.82 loss=27.3082
[step600] accuracy=0.92 loss=18.8941
[step700] accuracy=0.84 loss=27.9557
[step800] accuracy=0.9 loss=16.8646
[step900] accuracy=0.92 loss=12.3513
[step1000] accuracy=0.9 loss=22.4553
[*]Test Result=0.886900002956 at epoch8

After 9 epochs the accuracy is already around 88%. Roughly estimating, to reach the same 35% accuracy the network without BN needs 10 epochs while the one with BN needs about 3.4 epochs.

So BN is about 3x faster here, which is roughly the same speedup the paper reports for BN-Baseline vs. Inception. It should be possible to go even faster by tuning the learning rate.

Code:
###VGG.PY#########

import tensorflow as tf
"""
(1)构造函数__init__参数
  input_sz:  输入层placeholder的4-D shape,如mnist是[None,28,28,1]
  fc_layers: 全连接层每一层大小,接在卷积层后面。如mnist可以为[128,84,10],[10]
  conv_info: 卷积层、池化层。
    如vgg16可以这样写:[(2,64),(2,128),(3,256),(3,512),(3,512)],表示2+2+3+3+3=13个卷积层,4个池化层,以及channels
(2)train函数:训练一步
   batch_input: 输入的batch
   batch_output: label
   learning_rate:学习率
   返回:正确率和loss值(float)   格式:{"accuracy":accuracy,"loss":loss}
(3)forward:训练后用于测试
(4)save(save_path,steps)保存模型
(5)restore(path):从文件夹中读取最后一个模型
(6)loss函数使用cross-entrop one-hot版本:y*log(y_net)
(7)optimizer使用adamoptimier
"""
class VGG:    #VGG-style classifier
  sess=None
  #Tensor
  input=None 
  output=None
  desired_out=None
  loss=None
  iscorrect=None
  accuracy=None
  optimizer=None
  param_num=0             #number of parameters
  #hyper-parameters
  learning_rate=None 
  MOMENTUM         = 0.9
  WEIGHT_DECAY     = 1e-4       #L2 REGULARIZATION
  ACTIVATE         = None
  CONV_PADDING     = "SAME"
  MAX_POOL_PADDING = "SAME"
  CONV_WEIGHT_INITAILIZER = tf.truncated_normal_initializer(stddev=0.1)
  CONV_BIAS_INITAILIZER   = tf.constant_initializer(value=0.0)
  FC_WEIGHT_INITAILIZER   = tf.truncated_normal_initializer(stddev=0.1)
  FC_BIAS_INITAILIZER     = tf.constant_initializer(value=0.0)
  
  
  def train(self,batch_input,batch_output,learning_rate):  
    _,accuracy,loss=self.sess.run([self.optimizer,self.accuracy,self.loss],
       feed_dict={self.input:batch_input,self.desired_out:batch_output,self.learning_rate:learning_rate})
    return {"accuracy":accuracy,"loss":loss}
    
  def forward(self,batch_input):
    return self.sess.run(self.output,feed_dict={self.input:batch_input})
  
  def save(self,save_path,steps):
    saver=tf.train.Saver(max_to_keep=5)
    saver.save(self.sess,save_path,global_step=steps)
  def restore(self,restore_path):
    path=tf.train.latest_checkpoint(restore_path)
    print("[*]Restore from %s" %(path))
    if path==None:
      return False
    saver=tf.train.Saver(max_to_keep=5)
    saver.restore(self.sess,path)
    return True
    
  def bn(self,x,name="bn"):
    #return x
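    #NOTE: this simplified bn recomputes batch statistics from the current batch on
    #every call, including at test time; a full BN layer would typically also keep
    #moving averages of mean/variance during training and use them at inference.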
    axes = [d for d in range(len(x.get_shape()))]
    beta = self._get_variable("beta", shape=[],initializer=tf.constant_initializer(0.0))
    gamma= self._get_variable("gamma",shape=[],initializer=tf.constant_initializer(1.0))
    x_mean,x_variance=tf.nn.moments(x,axes)  
    y=tf.nn.batch_normalization(x,x_mean,x_variance,beta,gamma,1e-10,name)
    return y
    
  def get_optimizer(self):    #
    #Optimizer
    #self.optimizer = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss)
    #self.optimizer =tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.loss) #reaches the target error range after ~1300 steps
    self.optimizer =tf.train.MomentumOptimizer(self.learning_rate,self.MOMENTUM).minimize(self.loss)            #reaches the target error range after ~9000 steps
  
  #apply one convolution (+ bias) to x; the activation is applied by the caller
  def conv(self,x,name,channels,ksize=3):
    x_shape=x.get_shape()
    x_channels=x_shape[3].value
    weight_shape=[ksize,ksize,x_channels,channels]
    bias_shape=[channels]
    weight = self._get_variable("weight",weight_shape,initializer=self.CONV_WEIGHT_INITAILIZER)
    bias   = self._get_variable("bias",bias_shape,initializer=self.CONV_BIAS_INITAILIZER) 
    y=tf.nn.conv2d(x,weight,strides=[1,1,1,1],padding=self.CONV_PADDING,name=name)
    y=tf.add(y,bias,name=name)
    return y
  
  def max_pool(self,x,name):
    return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding=self.MAX_POOL_PADDING,name=name)
    
  #_get_variable wrapper: counts parameters and attaches l2_regularization
  def _get_variable(self,name,shape,initializer):
    param=1
    for i in range(0,len(shape)):
      param*=shape[i]
    self.param_num+=param
    
    if self.WEIGHT_DECAY>0:
      regularizer=tf.contrib.layers.l2_regularizer(self.WEIGHT_DECAY)
    else:
      regularizer=None 
    
    return tf.get_variable(name,
                           shape=shape,
                           initializer=initializer,
                           regularizer=regularizer)
                           
  def fc(self,x,num,name):
    x_num=x.get_shape()[1].value
    weight_shape=[x_num,num]
    bias_shape  =[num]
    weight=self._get_variable("weight",shape=weight_shape,initializer=self.FC_WEIGHT_INITAILIZER)
    bias  =self._get_variable("bias",shape=bias_shape,initializer=self.FC_BIAS_INITAILIZER)
    y=tf.add(tf.matmul(x,weight),bias,name=name)
    return y 
  def _loss(self): 
    cross_entropy=-tf.reduce_sum(self.desired_out*tf.log(tf.clip_by_value(self.output,1e-10,1.0)))
    regularization_losses=tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    self.loss = tf.add_n([cross_entropy]+regularization_losses)
    #tf.scalar_summary('loss', loss_)
    return self.loss
    
  def __init__(self,input_sz,fc_layers,conv_info=[],activate_fun=tf.nn.relu): #
    self.ACTIVATE=activate_fun
    self.param_num=0  #parameter counter
    self.sess=tf.Session()
    layers=[]
    #(1) placeholders (input, desired output, learning rate)
    #input
    self.input=tf.placeholder(tf.float32,input_sz,name="input") 
    layers.append(self.input)
    #BN applied directly to the input
    layers.append(self.bn(layers[-1]))
    
    #output
    output_sz=[None,fc_layers[-1]]
    self.desired_out=tf.placeholder(tf.float32,output_sz,name="desired_out")
    self.learning_rate=tf.placeholder(tf.float32,name="learning_rate")
    
    
    #(2) convolution + pooling blocks
    with tf.variable_scope("convolution"):
      conv_block_id=0
      for cur_layers in conv_info:  
      #add one convolution block
        with tf.variable_scope("conv_block_%d" %(conv_block_id)) as scope:
          cur_conv_num=cur_layers[0]   #number of conv layers stacked in this block
          cur_channels=cur_layers[1]   #channels of each conv layer in this block
          #stack cur_conv_num conv layers
          for conv_id in range(0,cur_conv_num): 
            with tf.variable_scope("conv_%d" %(conv_id)):
                #add one conv layer
                x=layers[-1] 
                """
                #order 1: x->bn->weight->relu
                x2=self.bn(x) 
                x3=self.conv(x2,channels=cur_channels,name="conv")
                x4=self.ACTIVATE(x3)
                """
                
                #"""
                #order 2: x->bn->relu->weight
                x2=self.bn(x)
                x3=self.ACTIVATE(x2)
                x4=self.conv(x3,channels=cur_channels,name="conv")  
                #"""
                
                """
                #order 3: x->weight->bn->relu
                x2=self.conv(x,channels=cur_channels,name="conv")
                x3=self.bn(x2)
                x4=self.ACTIVATE(x3)
                """
                layers.append(x4) 
          #each conv block is followed by a pooling layer
          last_layer=layers[-1]
          pool=self.max_pool(last_layer,"max_pool")
          layers.append(pool) 
          conv_block_id+=1
      
    #(3) flatten the conv output
    last_layer=layers[-1]
    last_shape=last_layer.get_shape()
    neu_num=1
    for dim in range(1,len(last_shape)): 
       neu_num*=last_shape[dim].value
    flat_layer=tf.reshape(last_layer,[-1,neu_num],name="flatten")
    layers.append(flat_layer) 
    
    #(4) fully connected layers. NOTE: do not apply ReLU after the last layer!
    with tf.variable_scope("full_connection"): 
        for fc_id in range(0,len(fc_layers)):
           with tf.variable_scope("fc_%d" %(fc_id)):
              num=fc_layers[fc_id]
              x=layers[-1]
              x2=self.bn(x)
              x3=self.ACTIVATE(x2,name="relu")   #same ordering as the conv blocks: x->bn->activation->weight
              y=self.fc(x3,num,"fc")          
              layers.append(y)  
      
    #(5) softmax and the loss
    self.output=tf.nn.softmax(layers[-1])
    #build the loss
    self._loss()
    #(6) bookkeeping: accuracy
    self.iscorrect=tf.equal(tf.argmax(self.desired_out,1),tf.argmax(self.output,1),name="iscorrect")
    self.accuracy=tf.reduce_mean(tf.cast(self.iscorrect,dtype=tf.float32),name="accuracy")
    #(7) optimizer and variable initialization
    self.get_optimizer()
    self.sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter("./tboard/",self.sess.graph)  
  def __del__(self):
    self.sess.close()
####VGG_MNIST.PY####
import VGG
import tensorflow as tf
import sys
from tensorflow.examples.tutorials.mnist import input_data
vgg=VGG.VGG([None,28,28,1],[100,100,100,10],activate_fun=tf.sigmoid)#,[(3,16),(3,32),(3,64),(3,128)])
#vgg=VGG.VGG([None,28,28,1],[10],[(2,16),(2,32),(2,64)])
print("param_num=%d" %(vgg.param_num))
#writer = tf.summary.FileWriter("./tboard/",vgg.sess.graph)
mnist = input_data.read_data_sets("input_data", one_hot=True) 
def get_mnist_batch(num,get_test=False):
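  #reshape each flat 784-vector MNIST image into the 28x28x1 layout the VGG class expects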
  batch=None 
  if get_test:
    batch=[mnist.test.images,mnist.test.labels]
  else:
    batch=mnist.train.next_batch(num)
    
  input=[]
  for x in batch[0]:
    inp=[[0 for _ in range(0,28)] for _ in range(0,28)]
    for row in range(0,28):
      for col in range(0,28):
        inp[row][col]=[x[row*28+col]]
        """
        if inp[row][col][0]>0.6:
          print(" ",end="")
        else:
         if inp[row][col][0]>0.3:
           print(".",end="")
         else:
           print("w",end="")
        if col==27:
          print("")
    sys.exit(0)
    """
    
    input.append(inp) 
  return input,batch[1]
  
def get_mnist_test_accuracy():
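  #evaluate on the 10000 test images in batches of 100; train() is called with learning_rate=0 so the weights are not updated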
  batch=get_mnist_batch(0,True)
  accuracy=0
  for st in range(0,10000,100):
    ret=vgg.train(batch[0][st:st+100],batch[1][st:st+100],learning_rate=0)
    accuracy+=ret["accuracy"]/100
  return accuracy
"""    
if vgg.restore("./model/"):
  test_acc=get_mnist_test_accuracy()
  print("[*]Test Result=%s at epoch%d" %(test_acc,0))
"""  
learning_rate=1e-4
for epoch in range(0,10):
  batch_sz=50   
  for i in range(int(50000/batch_sz)):
    batch = get_mnist_batch(batch_sz) 
    ret=vgg.train(batch[0],batch[1],learning_rate=learning_rate)
    if i%100==0: 
      #print(batch[1][0])
      #print(ret[0][0])
      print("[step%d] accuracy=%s loss=%s" %(i+100,ret["accuracy"],ret["loss"]))
  #learning_rate/=2
  vgg.save("model/mnist_epoch",epoch)
  test_acc=get_mnist_test_accuracy()
  print("[*]Test Result=%s at epoch%d" %(test_acc,epoch))