将tensorflow 1.x & 2.x转化成onnx文件(以arcface-tf2人脸识别模型为例)_ckpt转onnx-程序员宅基地

技术标签: tf2  # 计算机视觉  机器学习  # python  深度学习  onnx  

将tensorflow 1.x & 2.x转化成onnx文件

一、tensorflow 1.x转化成onnx文件

1、ckpt文件生成

参考 tensorflow实现将ckpt转pb文件

# -*- coding:utf-8 -*-

import tensorflow as tf

# 参考<https://blog.csdn.net/guyuealian/article/details/82218092>
# 声明两个变量
'''
适合tf1.x  (tf2.x保存权重文件时没有meta文件)
checkpoint是检查点文件,文件保存了一个目录下所有的模型文件列表;
model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构,该文件可以被 tf.train.import_meta_graph 加载到当前默认的图来使用。
ckpt.data : 保存模型中每个变量的取值
'''
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
init_op = tf.global_variables_initializer() # 初始化全部变量
saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
    sess.run(init_op)
    print("v1:", sess.run(v1)) # 打印v1、v2的值一会读取之后对比
    print("v2:", sess.run(v2))
    saver_path = saver.save(sess, "save/model.ckpt")  # 将模型保存到save/model.ckpt文件
    print("Model saved in file:", saver_path)

之后会在save文件夹下产生4个文件(index,checkpint,meta网络结构文件和权重文件,这里注意tf2在生成ckpt文件时,不会产生后缀名为meta的文件

|-save
	|-checkpoint
	|-model.ckpt.data-00000-of-00001
	|-model.ckpt.index
	|-model.ckpt.meta
2、打印权重参数名称

参考TensorFlow拾遗(一) 打印网络结构与变量

# -*- coding:utf-8 -*-

import tensorflow as tf
import os

'''适合tf1.x版本打印网络权重参数,tf2.x版本会报错AttributeError: module 'tensorflow_core._api.v2.train' has no attribute 'import_meta_graph'''
# 参考 <https://www.cnblogs.com/monologuesmw/p/13303745.html>
def txt_save(data, output_file):
    file = open(output_file, 'a')
    for i in data:
        s = str(i) + '\n'
        file.write(s)
    file.close()

def network_param(input_checkpoint, output_file=None):
    saver = tf.train.import_meta_graph(input_checkpoint + ".meta", clear_devices=True)
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)
        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        for i in variables:
            print(i)     # 打印
        txt_save(variables, output_file)  # 保存txt   二选一

if __name__ == '__main__':
    checkpoint_path = './save/model.ckpt'  #tensorFlow 2.0以上,否则会报text_format.Merge(file_content.decode("utf-8"), meta_graph_def),UnicodeDecodeError: 'utf-8' codec can't decode byte 0xcf in position 0: invalid continuation byte
    output_file = 'network_param.txt'
    if not os.path.exists(output_file):
        network_param(checkpoint_path, output_file)

输出结果如下

<tf.Variable 'v1:0' shape=(1, 2) dtype=float32_ref>
<tf.Variable 'v2:0' shape=(2, 3) dtype=float32_ref>
3、ckpt文件转pb

参考 tensorflow实现将ckpt转pb文件

# -*- coding:utf-8 -*-

''' 将CKPT 转换成 PB格式的文件的过程可简述如下'''
#参考<https://blog.csdn.net/guyuealian/article/details/82218092>
'''
1、函数freeze_graph中,最重要的就是要确定“指定输出的节点名称”,这个节点名称必须是原模型中存在的节点,对于freeze操作,我们需要定义输出结点的名字。因为网络其实是比较复杂的,定义了输出结点的名字,那么freeze的时候就只把输出该结点所需要的子图都固化下来,其他无关的就舍弃掉。因为我们freeze模型的目的是接下来做预测。所以,output_node_names一般是网络模型最后一层输出的节点名称,或者说就是我们预测的目标。
2、在保存的时候,通过convert_variables_to_constants函数来指定需要固化的节点名称,对于鄙人的代码,需要固化的节点只有一个:output_node_names。注意节点名称与张量的名称的区别,例如:“input:0”是张量的名称,而"input"表示的是节点的名称。
3、源码中通过graph = tf.get_default_graph()获得默认的图,这个图就是由saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)恢复的图,因此必须先执行tf.train.import_meta_graph,再执行tf.get_default_graph() 。
4、实质上,我们可以直接在恢复的会话sess中,获得默认的网络图,更简单的方法,如下:
'''
import tensorflow as tf
from tensorflow import graph_util

'''
通过传入 CKPT 模型的路径得到模型的图和变量数据
通过 import_meta_graph 导入模型中的图
通过 saver.restore 从模型中恢复图中各个变量的数据
通过 graph_util.convert_variables_to_constants 将模型持久化
'''
def freeze_graph(input_checkpoint, output_graph):
    '''
    :param input_checkpoint:
    :param output_graph: PB模型保存路径
    :return:
    '''
    # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
    # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径

    # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)  # 恢复图并得到数据
        output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定
            sess=sess,
            input_graph_def=sess.graph_def,  # 等于:sess.graph_def
            output_node_names=['v1','v2'])  #上面权重文件的参数名称

        with tf.gfile.GFile(output_graph, "wb") as f:  # 保存模型
            f.write(output_graph_def.SerializeToString())  # 序列化输出
        print("%d ops in the final graph." % len(output_graph_def.node))  # 得到当前图有几个操作节点

if __name__ == '__main__':
    # 输入ckpt模型路径
    input_checkpoint='./save/model.ckpt'
    # 输出pb模型的路径
    out_pb_path="./save/pb/frozen_model.pb"
    # 调用freeze_graph将ckpt转为pb
    freeze_graph(input_checkpoint,out_pb_path)
4、ckpt文件转onnx(–checkpoint)

注意:ckpt不用先转pb,直接用df2onnx的–checkpoint即可。以下命令在命令行操作,注意需要先在原模型中打印输出网络的输入(inputs)输出(outputs)名称

比如这样子:

参考

'''--checkpoint适用于tf1,不适用于tf2'''
python -m tf2onnx.convert --checkpoint tensorflow2onnx_test/save/model.ckpt.meta --inputs v1:0 --outputs v2:0 --output tensorflow2onnx_test/save/onnx/onnxModel.onnx --opset 9

二、tensorflow 2.x转化成onnx文件

tf2整合了Keras包,tf1上面的代码在tf2中基本无效了,这里参考tf2官方文档

1、ckpt转savemodel(pb)
1)错误用法(不能冻结权重生成pb)

参考https://blog.csdn.net/qq_37116150/article/details/105736728(好吧,冤有头债有主,只是为了避坑)

以下操作在原项目模型代码中进行操作,此步骤可参考实操练习

#将ckpt权重文件转化成pb文件 参考<https://blog.csdn.net/qq_37116150/article/details/105736728>  不太对,导出的pb不适用于tf2onnx
        Convert Keras model to ConcreteFunction
        注意这个Input,是自己定义的输入层名
        full_model = tf.function(lambda Input: model(Input))
        full_model = full_model.get_concrete_function(
            tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

        # Get frozen ConcreteFunction
        frozen_func = convert_variables_to_constants_v2(full_model)
        frozen_func.graph.as_graph_def()

        layers = [op.name for op in frozen_func.graph.get_operations()]
        print("-" * 50)
        print("Frozen model layers: ")
        for layer in layers:
            print(layer)

        print("-" * 50)
        print("Frozen model inputs: ")
        print(frozen_func.inputs)
        print("Frozen model outputs: ")
        print(frozen_func.outputs)

        # Save frozen graph from frozen ConcreteFunction to hard drive
        tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                          logdir="./checkpoints/",
                          name="arc_mbv2.pb",
                          as_text=False)

导出来只有一个pb文件。

2)正确用法(saved_model)

参考

tf.saved_model.save(model,"./checkpoints/pb")

导出来是一个文件夹。

|- pb
	|- assets
    |- variables
    	|-variables.data-00000-of-00001
    	|-variables.index
    |- saved_model.pb
2、pb转onnx
python -m tf2onnx.convert --saved-model RLDD_test/utils/recognition/pb_savemodel/  --output RLDD_test/utils/recognition/onnx/onnxModel1.onnx --opset 9

注意这里不用输入--inputs--outputs参数

输出结果如下:

Skipping registering GPU devices...
2022-03-28 11:07:02.884797: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1096] Device interconnect StreamExecutor with strength 1 edge matrix:
2022-03-28 11:07:02.885144: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1102]      0
2022-03-28 11:07:02.885621: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1115] 0:   N
2022-03-28 11:07:04.447991: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:814] Optimization results for grappler item: graph_to_optimize
2022-03-28 11:07:04.448224: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:816]   constant_folding: Graph size after: 716 nodes (-274), 1540 edges (-548), t
ime = 784.703ms.
2022-03-28 11:07:04.449014: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:816]   function_optimizer: function_optimizer did nothing. time = 5.174ms.
2022-03-28 11:07:04.449179: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:816]   constant_folding: Graph size after: 716 nodes (0), 1540 edges (0), time =
259.701ms.
2022-03-28 11:07:04.449614: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:816]   function_optimizer: function_optimizer did nothing. time = 9.665ms.
2022-03-28 11:07:50,510 - INFO - Using tensorflow=2.1.0, onnx=1.8.0, tf2onnx=1.9.3/1190aa
2022-03-28 11:07:50,510 - INFO - Using opset <onnx, 9>
2022-03-28 11:11:22,648 - INFO - Computed 0 values for constant folding
2022-03-28 11:14:15,289 - INFO - Optimizing ONNX model
2022-03-28 11:14:19,545 - INFO - After optimization: BatchNormalization -45 (53->8), Cast -1 (1->0), Const -156 (290->134), Identity -6 (6->0), Reshape -17 (18->1),
Transpose -225 (227->2)
2022-03-28 11:14:19,929 - INFO -
2022-03-28 11:14:19,930 - INFO - Successfully converted TensorFlow model RLDD_test/utils/recognition/pb_savemodel/ to ONNX
2022-03-28 11:14:19,931 - INFO - Model inputs: ['input_image']
2022-03-28 11:14:19,932 - INFO - Model outputs: ['OutputLayer']
2022-03-28 11:14:19,933 - INFO - ONNX model is saved at RLDD_test/utils/recognition/onnx/onnxModel1.onnx
3、小总结(★★★)
  1. 要想将ckpt转化成onnx,必须要有模型源码,否则无法确定tf2onnx.convert的inputs和outputs参数(适合tf1.x版本

  2. tf2由于没有meta文件,无法使用–checkpoint进行compile,所以只能先将ckpt转化成pb,注意命名格式为saved_model.pb,否则会报错:

    OSError: SavedModel file does not exist at: RLDD_test/utils/recognition/ckpt//{saved_model.pbtxt|saved_model.pb}
    
  3. tf2中的ckpt转化成savemodel时,会自动生成包含saved_model.pb的文件夹,如果使用错误方法生成pb文件的话,在运行pb转onnx命令时,会报错:

    RuntimeError: MetaGraphDef associated with tags 'serve' could not be found in SavedModel. To inspect available tag-sets in the SavedModel, please use the SavedModel
    CLI: `saved_model_cli` 
    

三、实操练习

任务:下载 arcface-tf2源码 + tf2预训练权重文件(自己之前尝试过训练mobileNetV3 + arcface head进行人脸识别,但是属实训不好,经常“炸炉”,因此这里万分感谢作者能够开源代码并提供预训练的权重文件),将基于tf2的人脸识别项目转移到onnx推理引擎中。
在使用该模型进行推理时,会发现MobileNetV2在使用ArcFace Loss之前需要进行特征和权重的L2归一化。这样再使用点积相似度dot)时,可以将计算的相似度落在0~1区间上,并且相似度越大,则该图像在特征空间上与原图像越近。(具体实例可以参考facex-zoo)

源码中有ResNet50和MobileNetV2两个模型,这里以MobileNetV2模型举个栗子

  1. 根据源码中README.md的提示,导入相应的模型权重文件,运行test.py

  2. 在test.py中加入这段代码,将ckpt模型保存成saveModel的形式

tf.saved_model.save(model,"./checkpoints/pb")
  1. 生成savemodel文件夹之后,使用如下命令将savemodel转化成onnx,注意修改savemodel路径。(此过程比较耗时,请耐心等待)
python -m tf2onnx.convert --saved-model RLDD_test/utils/recognition/pb_savemodel/  --output RLDD_test/utils/recognition/onnx/onnxModel1.onnx --opset 9

这里注意:不单单只有一个pb文件,如果使用冻结权重保存的pb文件,在运行savemodel转化成onnx的命令会报错

  1. 接着使用onnxRuntime推理引擎加载onnx文件,即可实现结果的输出:

参考python关于onnx模型的一些基本操作

import cv2
import onnx
from onnx import helper
import onnxruntime
import numpy as np

if __name__ == '__main__':

    #参考 <https://blog.csdn.net/CFH1021/article/details/108732114>
    '''一、获取onnx模型的输出层'''
    # # 加载模型
    # model = onnx.load('./onnx/onnxModel1.onnx')
    # # 检查模型格式是否完整及正确
    # onnx.checker.check_model(model)
    # # 获取输出层,包含层名称、维度信息
    # output = model.graph.output
    # print(output)

    '''二、获取中节点输出数据'''
    # 加载模型
    # model = onnx.load('./onnx/onnxModel1.onnx')
    # # 创建中间节点:层名称、数据类型、维度信息
    # prob_info = helper.make_tensor_value_info('layer1', onnx.TensorProto.FLOAT, [1, 3, 320, 280])
    # # 将构建完成的中间节点插入到模型中
    # model.graph.output.insert(0, prob_info)
    # # 保存新的模型
    # onnx.save(model, './onnx/onnxModel_new.onnx')

    # 扩展:
    # 删除指定的节点方法: item为需要删除的节点
    # model.graph.output.remove(item)

    '''三、onnx前向InferenceSession的使用'''
    '''
        关于onnx的前向推理,onnx使用了onnxruntime计算引擎。
       onnx runtime是一个用于onnx模型的推理引擎。微软联合Facebook等在2017年搞了个深度学习以及机器学习模型的格式标准–ONNX,顺路提供了一个专门用于ONNX模型推理的引擎(onnxruntime)。
    '''
    #参考 <https://zhuanlan.zhihu.com/p/261307813>
    # 创建一个InferenceSession的实例,并将模型的地址传递给该实例
    sess = onnxruntime.InferenceSession('./onnx/onnxModel1.onnx')
    #加载图片
    img = cv2.imread("../img/calibrate_glasses.jpg")
    img = cv2.resize(img, (112, 112))
    img = img.astype(np.float32) / 255.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    if len(img.shape) == 3:
        img = np.expand_dims(img, 0)
    # 调用实例sess的run方法进行推理
    outputs = sess.run([], {
     "input_image": img})
    print(outputs)   #模型结果层的输出

补充一点,没必要尝试用cv2.dnn.readNet来加载onnx模型,可能会因为opencv-python版本报如下错误:

cv2.error: OpenCV(4.4.0) C:\Users\appveyor\AppData\Local\Temp\1\pip-req-build-1hg9yufe\opencv\modules\dnn\src\layers\convolution_layer.cpp:94: error: (-213:The function/feature is not implemented) Unsupported asymmetric padding in convolution layer in function 'cv::dnn::BaseConvolutionLayerImpl::BaseConvolutionLayerImpl'

还不如老老实实用onnxRuntime

  1. 结合我之前独家配置的“秘方”(scrfd模型 + 自己构建的人脸数据库进行人脸检测),完成如下人脸识别任务。

在这里插入图片描述
在这里插入图片描述

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_33934427/article/details/123800910

智能推荐

使用nginx解决浏览器跨域问题_nginx不停的xhr-程序员宅基地

文章浏览阅读1k次。通过使用ajax方法跨域请求是浏览器所不允许的,浏览器出于安全考虑是禁止的。警告信息如下:不过jQuery对跨域问题也有解决方案,使用jsonp的方式解决,方法如下:$.ajax({ async:false, url: 'http://www.mysite.com/demo.do', // 跨域URL ty..._nginx不停的xhr

在 Oracle 中配置 extproc 以访问 ST_Geometry-程序员宅基地

文章浏览阅读2k次。关于在 Oracle 中配置 extproc 以访问 ST_Geometry,也就是我们所说的 使用空间SQL 的方法,官方文档链接如下。http://desktop.arcgis.com/zh-cn/arcmap/latest/manage-data/gdbs-in-oracle/configure-oracle-extproc.htm其实简单总结一下,主要就分为以下几个步骤。..._extproc

Linux C++ gbk转为utf-8_linux c++ gbk->utf8-程序员宅基地

文章浏览阅读1.5w次。linux下没有上面的两个函数,需要使用函数 mbstowcs和wcstombsmbstowcs将多字节编码转换为宽字节编码wcstombs将宽字节编码转换为多字节编码这两个函数,转换过程中受到系统编码类型的影响,需要通过设置来设定转换前和转换后的编码类型。通过函数setlocale进行系统编码的设置。linux下输入命名locale -a查看系统支持的编码_linux c++ gbk->utf8

IMP-00009: 导出文件异常结束-程序员宅基地

文章浏览阅读750次。今天准备从生产库向测试库进行数据导入,结果在imp导入的时候遇到“ IMP-00009:导出文件异常结束” 错误,google一下,发现可能有如下原因导致imp的数据太大,没有写buffer和commit两个数据库字符集不同从低版本exp的dmp文件,向高版本imp导出的dmp文件出错传输dmp文件时,文件损坏解决办法:imp时指定..._imp-00009导出文件异常结束

python程序员需要深入掌握的技能_Python用数据说明程序员需要掌握的技能-程序员宅基地

文章浏览阅读143次。当下是一个大数据的时代,各个行业都离不开数据的支持。因此,网络爬虫就应运而生。网络爬虫当下最为火热的是Python,Python开发爬虫相对简单,而且功能库相当完善,力压众多开发语言。本次教程我们爬取前程无忧的招聘信息来分析Python程序员需要掌握那些编程技术。首先在谷歌浏览器打开前程无忧的首页,按F12打开浏览器的开发者工具。浏览器开发者工具是用于捕捉网站的请求信息,通过分析请求信息可以了解请..._初级python程序员能力要求

Spring @Service生成bean名称的规则(当类的名字是以两个或以上的大写字母开头的话,bean的名字会与类名保持一致)_@service beanname-程序员宅基地

文章浏览阅读7.6k次,点赞2次,收藏6次。@Service标注的bean,类名:ABDemoService查看源码后发现,原来是经过一个特殊处理:当类的名字是以两个或以上的大写字母开头的话,bean的名字会与类名保持一致public class AnnotationBeanNameGenerator implements BeanNameGenerator { private static final String C..._@service beanname

随便推点

二叉树的各种创建方法_二叉树的建立-程序员宅基地

文章浏览阅读6.9w次,点赞73次,收藏463次。1.前序创建#include&lt;stdio.h&gt;#include&lt;string.h&gt;#include&lt;stdlib.h&gt;#include&lt;malloc.h&gt;#include&lt;iostream&gt;#include&lt;stack&gt;#include&lt;queue&gt;using namespace std;typed_二叉树的建立

解决asp.net导出excel时中文文件名乱码_asp.net utf8 导出中文字符乱码-程序员宅基地

文章浏览阅读7.1k次。在Asp.net上使用Excel导出功能,如果文件名出现中文,便会以乱码视之。 解决方法: fileName = HttpUtility.UrlEncode(fileName, System.Text.Encoding.UTF8);_asp.net utf8 导出中文字符乱码

笔记-编译原理-实验一-词法分析器设计_对pl/0作以下修改扩充。增加单词-程序员宅基地

文章浏览阅读2.1k次,点赞4次,收藏23次。第一次实验 词法分析实验报告设计思想词法分析的主要任务是根据文法的词汇表以及对应约定的编码进行一定的识别,找出文件中所有的合法的单词,并给出一定的信息作为最后的结果,用于后续语法分析程序的使用;本实验针对 PL/0 语言 的文法、词汇表编写一个词法分析程序,对于每个单词根据词汇表输出: (单词种类, 单词的值) 二元对。词汇表:种别编码单词符号助记符0beginb..._对pl/0作以下修改扩充。增加单词

android adb shell 权限,android adb shell权限被拒绝-程序员宅基地

文章浏览阅读773次。我在使用adb.exe时遇到了麻烦.我想使用与bash相同的adb.exe shell提示符,所以我决定更改默认的bash二进制文件(当然二进制文件是交叉编译的,一切都很完美)更改bash二进制文件遵循以下顺序> adb remount> adb push bash / system / bin /> adb shell> cd / system / bin> chm..._adb shell mv 权限

投影仪-相机标定_相机-投影仪标定-程序员宅基地

文章浏览阅读6.8k次,点赞12次,收藏125次。1. 单目相机标定引言相机标定已经研究多年,标定的算法可以分为基于摄影测量的标定和自标定。其中,应用最为广泛的还是张正友标定法。这是一种简单灵活、高鲁棒性、低成本的相机标定算法。仅需要一台相机和一块平面标定板构建相机标定系统,在标定过程中,相机拍摄多个角度下(至少两个角度,推荐10~20个角度)的标定板图像(相机和标定板都可以移动),即可对相机的内外参数进行标定。下面介绍张氏标定法(以下也这么称呼)的原理。原理相机模型和单应矩阵相机标定,就是对相机的内外参数进行计算的过程,从而得到物体到图像的投影_相机-投影仪标定

Wayland架构、渲染、硬件支持-程序员宅基地

文章浏览阅读2.2k次。文章目录Wayland 架构Wayland 渲染Wayland的 硬件支持简 述: 翻译一篇关于和 wayland 有关的技术文章, 其英文标题为Wayland Architecture .Wayland 架构若是想要更好的理解 Wayland 架构及其与 X (X11 or X Window System) 结构;一种很好的方法是将事件从输入设备就开始跟踪, 查看期间所有的屏幕上出现的变化。这就是我们现在对 X 的理解。 内核是从一个输入设备中获取一个事件,并通过 evdev 输入_wayland

推荐文章

热门文章

相关标签