Tensorflow "hello world"源码分析

“hello world” 代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import tensorflow as tf
# 创建一个常量 op, 产生一个 1x2 矩阵. 这个 op 被作为一个节点
# 加到默认图中.
#
# 构造器的返回值代表该常量 op 的返回值.
matrix1 = tf.constant([[3., 3.]])
# 创建另外一个常量 op, 产生一个 2x1 矩阵.
matrix2 = tf.constant([[2.],[2.]])
# 创建一个矩阵乘法 matmul op , 把 'matrix1' 和 'matrix2' 作为输入.
# 返回值 'product' 代表矩阵乘法的结果.
product = tf.matmul(matrix1, matrix2)
# 启动默认图.
sess = tf.Session()
# 调用 sess 的 'run()' 方法来执行矩阵乘法 op, 传入 'product' 作为该方法的参数.
# 上面提到, 'product' 代表了矩阵乘法 op 的输出, 传入它是向方法表明, 我们希望取回
# 矩阵乘法 op 的输出.
#
# 整个执行过程是自动化的, 会话负责传递 op 所需的全部输入. op 通常是并发执行的.
#
# 函数调用 'run(product)' 触发了图中三个 op (两个常量 op 和一个矩阵乘法 op) 的执行.
#
# 返回值 'result' 是一个 numpy `ndarray` 对象.
result = sess.run(product)
print result
# ==> [[ 12.]]
# 任务完成, 关闭会话.
sess.close()

一. 构建Const Tensor

1
matrix1 = tf.constant([[3., 3.]])
  1. 获取global default_graph,图在初始时会获取所有已注册的operator
    1
    self._registered_ops = op_def_registry.get_registered_ops()

operator的注册由gen_array_ops.py完成

Note: gen_array_ops.py是编译tensorflow时动态生成的文件,文件内容来自tensorflow\python\framework\python_op_gen.cc

这种动态生成文件的好处在哪儿那? 是为了可以生成包含custom operator的gen_array_ops.py?

  1. 将[[3., 3.]] 转换成Tensor
  • 将python的 list 转换成numpy ndarray
  • 构建TensorProto对象, 指定dtypetensor_shapetensor_context三个属性,分别代表Tensor的数据类型对应的enum(省空间,如float32对应值1)、形状(如 1 * 2 )、值对应的可序列化字符串(即nparray.toString() )

    proto格式可用于传输与保存:python -> c++,设备之间传输;也可用于保存与加载

  1. 创建Const operator

    • 构建NodeDef,指定name、attr(即dtype、value)
    • 根据output_types构建List[Tensor]
  2. 返回ConstOperator.output[0]作为Const Tensor

二. 构建矩阵乘法Tensor

1
product = tf.matmul(matrix1, matrix2)
  1. 也是先将输入转换成Tensor(此处matrix1、matrix2已是tensor)
  2. 构建matmul op,自动推导output tensors的数量与类型

    所有op的参数定义都位于文件E:\tensorflow-master\tensorflow\core\ops\ops.pbtxt

  3. 返回output tensor

三. 启动Session

1
sess = tf.Session()
  1. Tensorflow里Python相当于Client,C++ 为Server,session的实际运行位于c++。 通过Swig,python可调用C++的function。
  2. 根据配置信息获取session
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    Status NewSession(const SessionOptions& options, Session** out_session) {
    SessionFactory* factory;
    Status s = SessionFactory::GetFactory(options, &factory);
    if (!s.ok()) {
    *out_session = nullptr;
    LOG(ERROR) << s;
    return s;
    }
    *out_session = factory->NewSession(options);
    if (!*out_session) {
    return errors::Internal("Failed to create session.");
    }
    return Status::OK();
    }

若target为空,则为DirectSession,代表local模式
若target指定了ip:port列表,则为集群模式

四. 执行图计算

1
result = sess.run(product)
  1. 将tensor对应的python数据结构转换成c++格式的
  2. 调用directSession.run() (实际执行,包含图的构建、剪枝、优化,基于设备的分区等等,后续进行长文分析)