建立网络模型

08 Jun 2020

1. 利用keras API构建模型框架

主要用到的就是 Sequential类

Layer包括了weights和一个激活函数 (a call method, the layer’s forward pass).

1.1 函数定义

from tensorflow.keras import Sequential

#第一层要有输入的shape,后面几层可以直接从第一层推导出尺寸,不必显示列出
conv_net = Sequential([layers.Dense(32,activation = "relu",input_shape = (784,)),
                       layers.Dense(10,activation = "softmax")])

网络层类别

1.2 自定义网络层

一般常见的自定义网络层如下,其中build方法不是必需的,大部分情况下都可以省略:

from tensorflow.keras import layers

class Linear(layers.Layer):
  def __init__(self, units=32, input_dim=32):
    super(Linear, self).__init__()
    w_init = tf.random_normal_initializer()
    self.w = tf.Variable(initial_value=w_init(shape=(input_dim, units),
                                              dtype='float32'),
                         trainable=True)
    b_init = tf.zeros_initializer()
    self.b = tf.Variable(initial_value=b_init(shape=(units,),
                                              dtype='float32'),
                         trainable=True)

  def call(self, inputs):
    return tf.matmul(inputs, self.w) + self.b

x = tf.ones((2, 2))
linear_layer = Linear(4, 2)
y = linear_layer(x)
print(y)

参考:https://tensorflow.google.cn/guide/keras/custom_layers_and_models?hl=zh-cn

2. 模型加载自定义层

当我们自定义网络层并且有效保存模型后,希望使用tf.keras.models.load_model进行模型加载时,可能会报如下错误:

raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
ValueError: Unknown layer: Mylayer

解决方法:

首先,建立一个字典,该字典的键是自定义网络层时设定该层的名字,其值为该自定义网络层的类名,该字典将用于加载模型时使用!

_custom_objects = {
    "Mylayer" :  Mylayer,
}

然后,在tf.keras.models.load_model内传入custom_objects告知如何解析重建自定义网络层,其方法如下:

model = tf.keras.models.load_model("path/to/your/model", custom_objects=_custom_objects)

参考:https://zhuanlan.zhihu.com/p/86886620