最近在读《TensorFlow 内核剖析》这本书,作者刘光聪。有一些收获,记录一下。

TF的Session

Session是TensorFlow前后端连接的桥梁。用户利用session使得client能够与master的执行引擎建立连接,并通过session.run()来触发一次计算。它建立了一套上下文环境,封装了operation计算以及tensor求值的环境。

session之间采用共享graph的方式来提高运行效率。一个session只能运行一个graph实例,但一个graph可以运行在多个session中。创建session时如果不指定Graph实例,则会使用系统默认Graph。当session close时,默认 graph 引用计数减1。只有引用计数为0时,graph才会被回收。这种graph共享的方式,大大减少了graph创建和回收的资源消耗,优化了TensorFlow运行效率。

op运算和tensor求值时,如果没有指定运行在哪个session中,则会运行在默认session中。通过session.as_default()可以将自己设置为默认session。

operation.run()
tensor.eval()

实际执行的代码是

tf.get_default_session().run(operation)
tf.get_default_session().run(tensor)

Session 类型

前端 Session

分为普通Session和交互式InteractiveSession, 区别在于:

  • InteractiveSession创建后,会将自己替换为默认session。使得之后operation.run()和tensor.eval()的执行通过这个默认session来进行。特别适合Python交互式环境。

  • InteractiveSession自带with上下文管理器。它在创建时和关闭时会调用上下文管理器的enter和exit方法,从而进行资源的申请和释放,避免内存泄漏问题。这同样很适合Python交互式环境。

BaseSession基本包含了所有的会话实现逻辑。包括会话的整个生命周期,也就是创建 执行 关闭和销毁四个阶段。

BaseSession包含的主要成员变量有:

  • graph引用
  • 序列化的graph_def
  • 要连接的tf引擎target
  • session配置信息config

后端Session

后端master中,根据前端client调用tf.Session(target=’’, graph=None, config=None)时指定的target,来创建不同的Session。target为要连接的tf后端执行引擎,默认为空字符串。Session创建采用了抽象工厂模式,如果为空字符串,则创建本地DirectSession,如果以grpc://开头,则创建分布式GrpcSession。

DirectSession只能利用本地设备,将任务创建到本地的CPU GPU上。而GrpcSession则可以利用远端分布式设备,将任务创建到其他机器的CPU GPU上,然后通过grpc协议进行通信。

Session 生命周期

Session作为前后端连接的桥梁,以及上下文运行环境,其生命周期尤其关键。大致分为4个阶段

  • 创建:通过tf.Session()创建session实例,进行系统资源分配,特别是graph引用计数加1
  • 运行:通过session.run()触发计算的执行,client会将整图graph传递给master,由master进行执行
  • 关闭:通过session.close()来关闭,会进行系统资源的回收,特别是graph引用计数减1.
  • 销毁:Python垃圾回收器进行GC时,调用session.del()进行回收。

graph

可以显示创建Graph,并调用as_default()使他替换默认Graph。在该上下文管理器中创建的op都会注册到这个graph中。退出上下文管理器后,则恢复原来的默认graph。一般情况下,我们不用显式创建Graph,使用系统创建的那个默认Graph即可。

with tf.Graph().as_default() as g:
    print tf.get_default_graph() is g
    print tf.get_default_graph()

print tf.get_default_graph()

在上下文管理器中,当前线程的默认图被替换了,而退出上下文管理后,则恢复为了原来的默认图。

graph 类型

前端graph 类型

Python前端中,Graph的数据结构。Graph主要的成员变量是Operation和Tensor。Operation是Graph的节点,它代表了运算算子。Tensor是Graph的边,它代表了运算数据。

@tf_export("Graph")
class Graph(object):
    def __init__(self):
   	    # 加线程锁,使得注册op时,不会有其他线程注册op到graph中,从而保证共享graph是线程安全的
        self._lock = threading.Lock()
        
        # op相关数据。
        # 为graph的每个op分配一个id,通过id可以快速索引到相关op。故创建了_nodes_by_id字典
        self._nodes_by_id = dict()  # GUARDED_BY(self._lock)
        self._next_id_counter = 0  # GUARDED_BY(self._lock)
        # 同时也可以通过name来快速索引op,故创建了_nodes_by_name字典
        self._nodes_by_name = dict()  # GUARDED_BY(self._lock)
        self._version = 0  # GUARDED_BY(self._lock)
        
        # tensor相关数据。
        # 处理tensor的placeholder
        self._handle_feeders = {}
        # 处理tensor的read操作
        self._handle_readers = {}
        # 处理tensor的move操作
        self._handle_movers = {}
        # 处理tensor的delete操作
        self._handle_deleters = {}

graph 添加 op 是会保证线程安全的。

  def _add_op(self, op):
    # graph被设置为final后,就是只读的了,不能添加op了。
    self._check_not_finalized()
    
    # 保证共享graph的线程安全
    with self._lock:
      # 将op以id和name分别构建字典,添加到_nodes_by_id和_nodes_by_name字典中,方便后续快速索引
      self._nodes_by_id[op._id] = op
      self._nodes_by_name[op.name] = op
      self._version = max(self._version, op._id)

name_scope

name_scope 节点命名空间
使用name_scope对graph中的节点进行层次化管理,上下层之间通过斜杠分隔。

后端Graph

Graph

class Graph {
     private:
      // 所有已知的op计算函数的注册表
      FunctionLibraryDefinition ops_;

      // GraphDef版本号
      const std::unique_ptr<VersionDef> versions_;

      // 节点node列表,通过id来访问
      std::vector<Node*> nodes_;

      // node个数
      int64 num_nodes_ = 0;

      // 边edge列表,通过id来访问
      std::vector<Edge*> edges_;

      // graph中非空edge的数目
      int num_edges_ = 0;

      // 已分配了内存,但还没使用的node和edge
      std::vector<Node*> free_nodes_;
      std::vector<Edge*> free_edges_;
 }

后端中的Graph主要成员也是节点node和边edge。节点node为计算算子Operation,边为算子所需要的数据,或者代表节点间的依赖关系。这一点和Python中的定义相似。边Edge的持有它的源节点和目标节点的指针,从而将两个节点连接起来。

Edge

class Edge {
     private:
      Edge() {}

      friend class EdgeSetTest;
      friend class Graph;
      // 源节点, 边的数据就来源于源节点的计算。源节点是边的生产者
      Node* src_;

      // 目标节点,边的数据提供给目标节点进行计算。目标节点是边的消费者
      Node* dst_;

      // 边id,也就是边的标识符
      int id_;

      // 表示当前边为源节点的第src_output_条边。源节点可能会有多条输出边
      int src_output_;

      // 表示当前边为目标节点的第dst_input_条边。目标节点可能会有多条输入边。
      int dst_input_;
};

Edge既可以承载tensor数据,提供给节点Operation进行运算,也可以用来表示节点之间有依赖关系。对于表示节点依赖的边,其src_output_, dst_input_均为-1,此时边不承载任何数据。

Node

class Node {
 public:
    // NodeDef,节点算子Operation的信息,比如op分配到哪个设备上了,op的名字等,运行时有可能变化。
  	const NodeDef& def() const;
    
    // OpDef, 节点算子Operation的元数据,不会变的。比如Operation的入参列表,出参列表等
  	const OpDef& op_def() const;
 private:
  	// 输入边,传递数据给节点。可能有多条
  	EdgeSet in_edges_;

  	// 输出边,节点计算后得到的数据。可能有多条
  	EdgeSet out_edges_;
}

创建Node时不需要new OpDef,只需要从OpDef仓库中取出即可。因为元信息是确定的,比如Operation的入参个数等。

由Node和Edge,即可以组成图Graph,通过任何节点和任何边,都可以遍历完整图。Graph执行计算时,按照拓扑结构,依次执行每个Node的op计算,最终即可得到输出结果。入度为0的节点,也就是依赖数据已经准备好的节点,可以并发执行,从而提高运行效率。

系统中存在默认的Graph,初始化Graph时,会添加一个Source节点和Sink节点。Source表示Graph的起始节点,Sink为终止节点。Source的id为0,Sink的id为1,其他节点id均大于1.

参考资料

Graph_谢杨易的博客-CSDN博客

Session_谢杨易的博客-CSDN博客