博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
学习TensorFlow,TensorBoard可视化网络结构和参数
阅读量:5325 次
发布时间:2019-06-14

本文共 7173 字,大约阅读时间需要 23 分钟。

在学习深度网络框架的过程中,我们发现一个问题,就是如何输出各层网络参数,用于更好地理解,调试和优化网络?针对这个问题,TensorFlow开发了一个特别有用的可视化工具包:TensorBoard,既可以显示网络结构,又可以显示训练和测试过程中各层参数的变化情况。本博文分为四个部分,第一部分介绍相关函数,第二部分是代码测试,第三部分是运行结果,第四部分介绍相关参考资料。

一. 相关函数

TensorBoard的输入是tensorflow保存summary data的日志文件。日志文件名的形式如:events.out.tfevents.1467809796.lei-All-Series 或 events.out.tfevents.1467809800.lei-All-Series。TensorBoard可读的summary data有scalar,images,audio,histogram和graph。那么怎么把这些summary data保存在日志文件中呢?

数值如学习率,损失函数用scalar_summary函数。tf.scalar_summary(节点名称,获取的数据)

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))  tf.scalar_summary('accuracy', accuracy)

各层网络权重,偏置的分布,用histogram_summary函数

preactivate = tf.matmul(input_tensor, weights) + biases  tf.histogram_summary(layer_name + '/pre_activations', preactivate)

其他几种summary data也是同样的方式获取,只是对应的获取函数名称换一下。这些获取summary data函数节点和graph是独立的,调用的时候也需要运行session。当需要获取的数据较多的时候,我们一个一个去保存获取到的数据,以及一个一个去运行会显得比较麻烦。tensorflow提供了一个简单的方法,就是合并所有的summary data的获取函数,保存和运行只对一个对象进行操作。比如,写入默认路径中,比如/tmp/mnist_logs (by default)

merged = tf.merge_all_summaries()  train_writer = tf.train.SummaryWriter(FLAGS.summaries_dir + '/train', sess.graph)  test_writer = tf.train.SummaryWriter(FLAGS.summaries_dir + '/test')

SummaryWriter从tensorflow获取summary data,然后保存到指定路径的日志文件中。以上是在建立graph的过程中,接下来执行,每隔一定step,写入网络参数到默认路径中,形成最开始的文件:events.out.tfevents.1467809796.lei-All-Series 或 events.out.tfevents.1467809800.lei-All-Series。

for i in range(FLAGS.max_steps):  if i % 10 == 0:  # Record summaries and test-set accuracysummary, acc = sess.run([merged, accuracy], feed_dict=feed_dict(False))        test_writer.add_summary(summary, i)   print('Accuracy at step %s: %s' % (i, acc))   else: # Record train set summarieis, and train      summary, _ = sess.run([merged, train_step], feed_dict=feed_dict(True))        train_writer.add_summary(summary, i)



tensorflow 可视化

tensorflow的可视化是使用summarytensorboard合作完成的.

基本用法

首先明确一点,summary也是op.

输出网络结构

with tf.Session() as sess:  writer = tf.summary.FileWriter(your_dir, sess.graph)
  • 1
  • 2
  • 1
  • 2

命令行运行tensorboard --logdir=your_dir,然后浏览器输入127.0.1.1:6006

这样你就可以在tensorboard中看到你的网络结构图了

可视化参数

#opsloss = ...tf.summary.scalar("loss", loss)merged_summary = tf.summary.merge_all()init = tf.global_variable_initializer()with tf.Session() as sess:  writer = tf.summary.FileWriter(your_dir, sess.graph)  sess.run(init)  for i in xrange(100):    _,summary = sess.run([train_op,merged_summary], feed_dict)    writer.add_summary(summary, i)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

这时,打开tensorboard,在EVENTS可以看到loss随着i的变化了,如果看不到的话,可以在代码最后加上writer.flush()试一下,原因后面说明。

函数介绍

  • tf.summary.merge_all: 将之前定义的所有summary op整合到一起

  • FileWriter: 创建一个file writer用来向硬盘写summary数据,

  • tf.summary.scalar(summary_tags, Tensor/variable): 用于标量的 summary

  • tf.summary.image(tag, tensor, max_images=3, collections=None, name=None):tensor,必须4维,形状[batch_size, height, width, channels],max_images(最多只能生成3张图片的summary),觉着这个用在卷积中的kernel可视化很好用.max_images确定了生成的图片是[-max_images: ,height, width, channels],还有一点就是,TensorBord中看到的image summary永远是最后一个global step

  • tf.summary.histogram(tag, values, collections=None, name=None):values,任意形状的tensor,生成直方图summary

  • tf.summary.audio(tag, tensor, sample_rate, max_outputs=3, collections=None, name=None)

FileWriter

注意:add_summary仅仅是向FileWriter对象的缓存中存放event data。而向disk上写数据是由FileWrite对象控制的。下面通过FileWriter的构造函数来介绍这一点!!!

tf.summary.FileWriter.__init__(logdir, graph=None, max_queue=10, flush_secs=120, graph_def=None)Creates a FileWriter and an event file.# max_queue: 在向disk写数据之前,最大能够缓存event的个数# flush_secs: 每多少秒像disk中写数据,并清空对象缓存
  • 1
  • 2
  • 3
  • 4
  • 5
  • 1
  • 2
  • 3
  • 4
  • 5

注意

  1. 如果使用writer.add_summary(summary,global_step)时没有传global_step参数,会使scarlar_summary变成一条直线。

  2. 只要是在计算图上的Summary op,都会被merge_all捕捉到, 不需要考虑变量生存空间问题!

  3. 如果执行一次,disk上没有保存Summary数据的话,可以尝试下file_writer.flush()

小技巧

如果想要生成的summary有层次的话,记得在summary外面加一个name_scope

with tf.name_scope("summary_gradients"):    tf.summary.histgram("name", gradients)
  • 1
  • 2
  • 1
  • 2

这样,tensorboard在显示的时候,就会有一个sumary_gradients一集目录。


主旨

在TensorFlow中每开发一个模型,都可以使用可视化调试工具TensorBoard得到这个session的Graph,这张图的结构和内容都不同于机器学习教材上介绍的典型神经网络结构图。本文试图通过代码实验理解Graph的含义,用以指导日常调试。

代码和运行环境

代码: singleNerualNode.py

运行环境:

Python 2.7.12 (default, Nov 19 2016, 06:48:10)
[GCC 5.4.0 20160609] on linux2
>>> tf.__version__
‘1.0.0-rc2’

问题介绍

在TensorFlow开发中,是一项很有用的可视化调试工具。在TensorBoard中,除了开发者自定义输出的数据结构之外,还包括表征神经网络模型的GRAPH。

在常见的机器学习教材中,神经网络的结构一般通过类似于如下的图形表示,下图引用自

这里写图片描述

但是,TensorBoard生成的Graph与上述形态完全不同,我自己开发的某个神经网络生成的Graph如下图所示

这里写图片描述

如何理解这张图?本文试图通过代码实验做一些尝试。

案例1:单神经元

为了简化分析场景,我们设计一个由单个神经元构成的神经网络,这个神经元存在numberOfInputDims输入,神经元的每条输入边都有权重因子wi,此外神经元还有bias偏置项,激活函数为sigmoid. 这个结构可以用如下的结构图描述

这里写图片描述

在TensorFlow中,如下代码即可定义出满足上述结构的神经网络,其中a就是上图中的Y

inputTensor = tf.placeholder(tf.float32, [None, numberOfInputDims], name='inputTensor')  labelTensor=tf.placeholder(tf.float32, [None, 1], name='LabelTensor')  W = tf.Variable(tf.random_uniform([numberOfInputDims, 1], -1.0, 1.0), name='weights')  b = tf.Variable(tf.zeros([1]), name='biases')  a = tf.nn.sigmoid(tf.matmul(inputTensor, W) + b, name='activation')
  • 1
  • 2
  • 3
  • 4
  • 5

使用TensorBoard生成的Graph如下图所示

Fig 0

问题1:Graph中的边(Edge)代表什么

根据TF官方文档对于,上图中实线表示数据依赖(tensor在运算符之间的流动关系),而虚线表示控制依赖关系,原文引用如下

TensorFlow graphs have two kinds of connections: data dependencies and control dependencies. Data dependencies show the flow of tensors between two ops and are shown as solid arrows, while control dependencies use dotted lines.

但是,我们得到的Graph中所有的实线都没有标出箭头方向,他们之间是谁依赖谁呢?

回答这个问题需要回到代码中,从代码可以知道:weight是由tf.random_uniform([numberOfInputDims, 1]初始化得到的;weight和inputTensor做矩阵乘法得到中间变量;中间变量再加上bias得到激活函数的输入;以此类推。

因此,TensorBoard Graph的上下方位代表了数据依赖的方向:数据总是从下方的节点流向上方的节点,上方节点依赖于下方节点

接下来讨论控制依赖。

上图中weight和bias节点都存在依赖于init运算的虚线,这说明weight和bias节点都需要初始化。虚线指向的运算符(op)是被依赖的运算符(op)

问题2:Graph中的节点代表什么

官方文档对于给出了如下图例表格

这里写图片描述

从中可以看出,不考虑summary node的情况下,节点要么是常数Constant,要么是运算符Operation Node,要么是前两者的组合。

回到我们这个具体问题,random_uniform/weight/bias就是组合节点,而其他节点就是运算符。注意:从上图可以看到,inputTensor也是被视作一个独立的运算符。

放大其中一个节点weights,我们可以看到其内部结构如下,包含:赋值(assign)运算符、(weight)运算符、读取(read)运算符

这里写图片描述

结合上述实验和分析,可以初步判断运算符Operation Node包含以下情况

  • 具体的运算操作:例如矩阵乘法(上图中的matmul),赋值(assign),读取(read)
  • 某个Tensor本身:例如上图中的inputTensor和(weights)节点

关于Operation Node代表常量的情况,在本文的案例2中会有体现

案例2:单神经元+损失函数+误差反向传播梯度下降调整参数

案例1中的神经网络只包含了前向计算逻辑,作为神经网络,其最主要的功能在于被训练满足某一目标,因此损失函数和误差反向传播梯度下降调整参数是必不可少的。

在此前的代码基础上新增如下两行就可以实现上述功能。这里损失函数使用预测值与目标值的

loss = tf.nn.l2_loss(a - labels, name='L2Loss')  train_step = tf.train.AdamOptimizer(1e-4).minimize(loss)
  • 1
  • 2

加入这两个功能后,新生成的Graph如下所示

这里写图片描述

对比上图和案例1中的Graph,在案例2 Graph的Main View(图像左侧)新增了如下信息

  • 节点
    • 常量y:代表label
    • sub运算符
    • L2Loss Tensor
    • gradient命名空间,这个命名空间与除了bias/add/L2Loss之外所有节点都有数据依赖
    • Weights和Bias都新增了对Adam节点的控制依赖

接下来重点讨论Gradient(梯度)和Adam(寻优算法)节点的内部结构

Gradient(梯度)节点

这里写图片描述
Gradient(梯度)节点内部是一个链式结构,分为对L2Loss求导、对减法求导、对激活函数求导……
总结起来,这就是数学上求导数的链式法则的图形化表示,我们的最终目的是求出损失函数对某个参数的导数,那么根据链式法则,只要从损失函数L2Loss出发,将每一层求导结果相乘就可以得到,笔者前一个系列利用Python实现神经网络中也使用了

Adam(寻优算法)节点

这里写图片描述
Adam算法本身的原理和实现可以参阅笔者此前的文章:
上述图形就是根据Adam算法原理,使用梯度、beta1\beta2,结合每一轮训练中weights和bias的原始值,计算weight和bias的更新值,通过控制依赖关系进一步调节weights和bias

总结

TensorBoard中的Graph不同于一般的神经网络结构图,它是一种计算图。图中每个节点要么是Tensor本身,要么是运算符,每一条边要么代表Tensor的流动,要么代表控制关系。这张图完备的表达了通过代码定义的神经网络中所有计算步骤,可以据此说明前向计算、误差反相传播、梯度下降调整参数等过程。

在实际工作中,理解了上述含义,就可以将Graph利用起来,在Debug过程中可视化的发现网络计算流程中的问题。复杂程序的调试总是困难的,引入可视化工具对于调试效率会非常有帮助。



参考文献:

转载于:https://www.cnblogs.com/yifdu25/p/8270849.html

你可能感兴趣的文章
界面交互之支付宝生活圈pk微信朋友圈
查看>>
字符串比较
查看>>
epoll 技术(转)
查看>>
<转>Shell脚本相关
查看>>
使用FreeMarker加载远程主机上模板文件,比如FTP,Hadoop等(转载)
查看>>
Java的位运算符具体解释实例——与(&amp;)、非(~)、或(|)、异或(^)
查看>>
java 注解 学习
查看>>
[leetcode]403. Frog Jump青蛙过河
查看>>
英语音节知识
查看>>
IEEE 802.15.4协议学习之MAC层
查看>>
AngularJS学习篇(十三)
查看>>
Tableau 学习资料
查看>>
中断和异常
查看>>
lucene 全文检索工具的介绍
查看>>
C# MD5-16位加密实例,32位加密实例
查看>>
无线点餐系统初步构思
查看>>
AJAX
查看>>
前端之CSS
查看>>
List注意点【修改】
查看>>
sqoop导入导出对mysql再带数据库test能跑通用户自己建立的数据库则不行
查看>>