欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

mxnet通过模型的json文件和params文件导出模型的结构图

程序员文章站 2024-03-14 20:46:59
...

导读

有时候我们需要导出网络的结构图,来了解网络的结构网络的输入输出节点等信息

导出网络结构图

通过mxnet模型的json文件和params文件可以很容易的导出模型的结构图,代码如下

  • 下载模型的json文件和params文件

这里我们以ResNet-18网络结构为例,通过下面的代码先下载需要的文件

import mxnet as mx

def download_model():
  path = 'http://data.mxnet.io/models/imagenet/'
  [mx.test_utils.download(path + 'resnet/18-layers/resnet-18-0000.params'),
   mx.test_utils.download(path + 'resnet/18-layers/resnet-18-symbol.json'),
   mx.test_utils.download(path + 'synset.txt')]
  • 导出网络的结构图

这里默认将网络的结构保存为PDF文件,可以通过修改plot_network函数中的save_format参数来设置保存的格式

sym,arg_params,aux_params = mx.model.load_checkpoint("resnet-18",0)
a = mx.viz.plot_network(sym, shape={"data": (1, 3, 224, 224)}, node_attrs={"shape": 'rect', "fixedsize": 'false'})
a.render('resnet-18')
  • 网络结构图
    mxnet通过模型的json文件和params文件导出模型的结构图