TensorFlow: fine-tuning selected layers' parameters from a pretrained network model


Method 1:
1. Start TF training and stop it once a checkpoint has been written. Then get the full names of the relevant layers' variables with:

    var_names = tf.contrib.framework.list_variables("/Users/kylefan/hobotrl/log/AutoExposureDPGVGG/model.ckpt-0")  # list of (name, shape) pairs
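Step 1 is mostly about picking names. list_variables returns (name, shape) pairs, and the sketch below (plain Python, no TensorFlow needed) applies the same selection rule as the code in step 2: keep variables under the learn/se scope and drop Adam's slot variables. The function name and the sample names are hypothetical, chosen only for illustration.

```python
def select_finetune_vars(var_list, scope="learn/se"):
    """Return names from (name, shape) pairs that start with `scope`,
    skipping Adam's per-variable slots (names ending in 'Adam' or 'Adam_1')."""
    selected = []
    for name, _shape in var_list:
        # Note: a plain prefix test would also match e.g. 'learn/se2/...'
        if name.startswith(scope) and not name.endswith(("Adam", "Adam_1")):
            selected.append(name)
    return selected


# Hypothetical output of tf.contrib.framework.list_variables(...)
sample = [
    ("learn/se/conv1/weights", [3, 3, 3, 64]),
    ("learn/se/conv1/weights/Adam", [3, 3, 3, 64]),
    ("learn/se/conv1/weights/Adam_1", [3, 3, 3, 64]),
    ("learn/se/conv1/biases", [64]),
    ("learn/actor/fc/weights", [256, 1]),
    ("global_step", []),
]
print(select_finetune_vars(sample))
# ['learn/se/conv1/weights', 'learn/se/conv1/biases']
```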

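2. Build a checkpoint by hand: assign a value to each variable found in the previous step, save the result as a new ckpt with tf.saver…, use it in place of the previous step's ckpt, and change the path recorded in the "checkpoint" file: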

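    import numpy as np
    import tensorflow as tf  # TF 1.x API: tf.Session and tf.contrib do not exist in TF 2

    src_ckpt = "/Users/kylefan/hobotrl/log/AutoExposureDPGVGG/model.ckpt-0"
    # Pretrained values as an .npz archive; this file name is hypothetical, the original does not give it
    weights = np.load("pretrained_weights.npz")
    # ckpt2npz_name() is the author's helper (not shown) that maps a ckpt variable name to its npz key

    with tf.Session() as sess:
        for var_name, _ in tf.contrib.framework.list_variables(src_ckpt):
            # Keep only the layers to fine-tune and skip Adam's slot variables
            if var_name.startswith('learn/se'):
                if not (var_name.endswith('Adam') or var_name.endswith('Adam_1')):
                    value_npz = ckpt2npz_name(var_name)  # translate the ckpt name into the npz key
                    # New variable with the same full name, initialized with the pretrained value
                    tf.Variable(weights[value_npz], name=var_name)

        # Save the new variables into a fresh checkpoint
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        checkpoint_path = "/Users/kylefan/hobotrl/log/AutoExposureDPGVGG/from_npz/model.ckpt-0"
        saver.save(sess, checkpoint_path)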
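The last part of step 2, pointing the "checkpoint" file at the new checkpoint, can also be scripted. That file is a small text protobuf whose model_checkpoint_path line names the latest checkpoint prefix. The sketch below rewrites it using only the standard library; point_checkpoint_file_to is a hypothetical helper name, and it assumes the path contains no quotes or backslashes (which would need escaping). Inside TensorFlow 1.x, tf.train.update_checkpoint_state(save_dir, model_checkpoint_path) does the same job.

```python
import os
import tempfile


def point_checkpoint_file_to(save_dir, ckpt_prefix):
    """Rewrite <save_dir>/checkpoint so `ckpt_prefix` is treated as the latest checkpoint.

    `ckpt_prefix` is the path given to saver.save() (no .index/.data/.meta suffix);
    it may be absolute or relative to `save_dir`.
    """
    content = (
        'model_checkpoint_path: "{0}"\n'
        'all_model_checkpoint_paths: "{0}"\n'.format(ckpt_prefix)
    )
    path = os.path.join(save_dir, "checkpoint")
    with open(path, "w") as f:
        f.write(content)
    return path


# Demo on a throwaway directory; for step 2 you would pass the original log dir
# and the from_npz/model.ckpt-0 prefix instead.
demo_dir = tempfile.mkdtemp()
demo_path = point_checkpoint_file_to(demo_dir, "/tmp/from_npz/model.ckpt-0")
with open(demo_path) as f:
    print(f.read())
```

Keep in mind that saver.save() already writes its own "checkpoint" file next to the new ckpt (here in from_npz/); the helper is for the original log directory that training reads from.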

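3. Specify restore_var_list so that TF does not fail because the checkpoint holds fewer variables than the graph, and print the fine-tuned variable values to make sure they are correct: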

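    # agent, config and args come from the surrounding (hobotrl) training script;
    # restore_var_list limits which variables are restored from save_dir
    restore_var_list = []
    for var in tf.global_variables():
        # Same selection as step 2: the 'learn/se' layers, without Adam slot variables
        if ('learn/se' in var.name) and ('Adam' not in var.name):
            restore_var_list.append(var)
    with agent.create_session(config=config, save_dir=args.logdir, restore_var_list=restore_var_list) as sess:
        # Dump every variable's name and value so the fine-tuned weights can be checked
        all_vars = tf.global_variables()
        with open(args.logdir + "/weight_fine_tuned.txt", "w") as f:
            for var in all_vars:
                f.write("{}\n".format(var.name))
                var_value = sess.run(var)
                f.write("{}\n\n".format(var_value))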
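A detail worth checking when building restore_var_list: var.name in the graph carries an output suffix such as ":0", while the keys stored in a checkpoint do not (they match var.op.name). The plain-Python sketch below compares the two lists to see which target variables the hand-made checkpoint covers and which it would miss. split_restorable and the sample names are hypothetical; in real code the inputs would come from tf.global_variables() and list_variables.

```python
def split_restorable(graph_var_names, ckpt_names, scope="learn/se"):
    """Split graph variables in `scope` (Adam slots excluded) into those present
    in the checkpoint and those missing from it.

    graph_var_names: names like 'learn/se/conv1/weights:0' (tf.Variable.name)
    ckpt_names: keys like 'learn/se/conv1/weights' (first item of list_variables pairs)
    """
    ckpt_keys = set(ckpt_names)
    found, missing = [], []
    for name in graph_var_names:
        if scope not in name or "Adam" in name:
            continue  # same filter as restore_var_list above
        key = name.rsplit(":", 1)[0]  # drop the ':0' output index
        if key in ckpt_keys:
            found.append(key)
        else:
            missing.append(key)
    return found, missing


graph_vars = ["learn/se/conv1/weights:0", "learn/se/conv1/weights/Adam:0",
              "learn/se/conv2/weights:0", "learn/actor/fc/weights:0"]
ckpt = ["learn/se/conv1/weights"]
found, missing = split_restorable(graph_vars, ckpt)
print(found, missing)  # ['learn/se/conv1/weights'] ['learn/se/conv2/weights']
```

If missing is non-empty, restoring restore_var_list as-is would fail on those names, so either add them to the hand-made checkpoint or leave them out of the list.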

Method 2:
If you already have a downloaded ckpt file, fine-tune directly from it.
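A downloaded checkpoint usually stores its variables under its own scope (TF-slim's VGG-16 checkpoint, for example, uses names such as vgg_16/conv1/conv1_1/weights), which rarely matches the scope in your graph. In TF 1.x, tf.train.Saver also accepts var_list as a dict whose keys are the names in the checkpoint and whose values are the variables to load into. The plain-Python sketch below builds that key mapping by swapping a scope prefix; remap_scope and both prefixes are assumptions for illustration.

```python
def remap_scope(ckpt_names, ckpt_prefix, graph_prefix):
    """Map checkpoint variable names to graph variable names by swapping a scope prefix.

    Returns {checkpoint_name: graph_name}; names outside `ckpt_prefix` are skipped.
    """
    mapping = {}
    for name in ckpt_names:
        if name.startswith(ckpt_prefix + "/"):
            mapping[name] = graph_prefix + name[len(ckpt_prefix):]
    return mapping


names = ["vgg_16/conv1/conv1_1/weights", "vgg_16/conv1/conv1_1/biases", "global_step"]
print(remap_scope(names, "vgg_16", "learn/se"))
# {'vgg_16/conv1/conv1_1/weights': 'learn/se/conv1/conv1_1/weights',
#  'vgg_16/conv1/conv1_1/biases': 'learn/se/conv1/conv1_1/biases'}
```

On the TensorFlow 1.x side, each graph name would then be replaced by the corresponding variable object before passing the dict as var_list to tf.train.Saver and calling restore().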

The .ckpt file is what saver.save(sess, ...) writes in the old (V1) checkpoint format; it is the equivalent of the .data file described below.

The "checkpoint" file only records which checkpoint is the latest (plus the list of recent ones); functions such as tf.train.latest_checkpoint() read it.

The .meta file (e.g. model.ckpt-1000.meta) contains the MetaGraph, i.e. the structure of your computation graph without the values of the variables (basically what you see in TensorBoard's graph view).

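The .data file(s) (e.g. model.ckpt-1000.data-00000-of-00001) contain the values of all the variables, without the structure. To restore a model in Python you usually combine the graph from the .meta file with the values stored under the checkpoint prefix (you can also use the .pb file instead):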

    saver = tf.train.import_meta_graph(path_to_ckpt_meta)  # e.g. "model.ckpt-1000.meta"
    saver.restore(sess, path_to_ckpt_prefix)  # the prefix, e.g. "model.ckpt-1000", not the .data file itself
The .index file (e.g. model.ckpt-1000.index) maps each variable name to the location of its value in the .data file(s). It is not optional: restore() reads the .index and .data files together, so you need both. Only the .meta file can be skipped, namely when you rebuild the graph in your own Python code instead of importing it.
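Because restore() needs the .index file and every .data shard for a given prefix, it can help to check a checkpoint directory before copying or restoring it. The standard-library sketch below groups file names by checkpoint prefix and reports whether each prefix is complete. It assumes the V2 naming model.ckpt-N.{index,meta,data-XXXXX-of-YYYYY} and does not handle old single-file V1 checkpoints; the function names are hypothetical.

```python
import re
from collections import defaultdict

# e.g. model.ckpt-0.index, model.ckpt-0.meta, model.ckpt-0.data-00000-of-00001
_PATTERN = re.compile(
    r"^(?P<prefix>.+)\.(?P<kind>index|meta|data-(?P<shard>\d+)-of-(?P<total>\d+))$")


def group_checkpoint_files(filenames):
    """Return {prefix: {"index": bool, "meta": bool, "data": set of shard ids, "total": int or None}}."""
    groups = defaultdict(lambda: {"index": False, "meta": False, "data": set(), "total": None})
    for fname in filenames:
        m = _PATTERN.match(fname)
        if not m:
            continue  # e.g. the "checkpoint" state file or unrelated files
        g = groups[m.group("prefix")]
        kind = m.group("kind")
        if kind == "index":
            g["index"] = True
        elif kind == "meta":
            g["meta"] = True
        else:
            g["data"].add(int(m.group("shard")))
            g["total"] = int(m.group("total"))
    return dict(groups)


def is_restorable(group):
    """True if the index file and all data shards are present (.meta is optional)."""
    return group["index"] and group["total"] is not None and len(group["data"]) == group["total"]


files = ["checkpoint", "model.ckpt-0.index", "model.ckpt-0.meta",
         "model.ckpt-0.data-00000-of-00001", "model.ckpt-100.meta",
         "model.ckpt-100.data-00000-of-00001"]
for prefix, g in sorted(group_checkpoint_files(files).items()):
    print(prefix, is_restorable(g))
# model.ckpt-0 True
# model.ckpt-100 False
```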

The .pb file is a serialized GraphDef protocol buffer. When it is produced by freeze_graph, which takes the meta and data files and turns the variables into constants, it holds the whole graph together with its weights; that is the form you usually load to run (but not train) a graph, for example from C++. Be careful: at least in older TF versions, some people found that the Python function provided by freeze_graph did not work properly and had to use the script version instead. TensorFlow also provides tf.train.Saver.to_proto(), which converts the Saver itself into a SaverDef protocol buffer describing its save and restore ops; it does not store any variable values.


Author: 11744
Original article: https://blog.csdn.net/fk1174/article/details/79731080

  • Posted 2019-05-15 11:51
  • Category: tensorflow
