《tensorflow机器学习实战指南》的源码
https://github.com/nfmcclure/tensorflow_cookbook
全都是.ipynb 文件
pip install --upgrade pip
pip install jupyter notebook
jupyter notebook
参考官方文档
https://www.tensorflow.org/programmers_guide/saved_model?hl=zh-cn
save把变量都存下来
savetest.py
import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)
inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-2)
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
sess.run(init_op)
# Do some work with the model.
inc_v1.op.run()
dec_v2.op.run()
# Save the variables to disk.
save_path = saver.save(sess, "./tmp/model.ckpt")
print("Model saved in path: %s" % save_path)
print("v1 : %s" % v1.eval())
print("v2 : %s" % v2.eval())
restore把存都变量都取出来
虽然定义来相同的 变量,restore都时候会覆盖掉
如果新程序便利多余save都程序,会报找不到变量都错误
storetest.py
import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
tf.reset_default_graph()
# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
#v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)
#v3 = tf.get_variable("v3", shape=[7], initializer = tf.zeros_initializer)
inc_v1 = v1.assign(v1+1)
#dec_v2 = v2.assign(v2-3)
#dec_v3 = v3.assign(v3-7)
init_op = tf.global_variables_initializer()
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
sess.run(init_op)
#Do some work with the model.
inc_v1.op.run()
#dec_v2.op.run()
#dec_v3.op.run()
# Restore variables from disk.
saver.restore(sess, "./tmp/model.ckpt")
print("Model restored.")
# Check the values of the variables
print("v1 : %s" % v1.eval())
# print("v2 : %s" % v2.eval())
# print("v3 : %s" % v3.eval())
如果想看所有变量,可以打印
# import the inspect_checkpoint library
from tensorflow.python.tools import inspect_checkpoint as chkp
# print all tensors in checkpoint file
chkp.print_tensors_in_checkpoint_file("./tmp/model.ckpt", tensor_name='', all_tensors=True)
print("------")
# tensor_name: v1
# [ 1. 1. 1.]
# tensor_name: v2
# [-1. -1. -1. -1. -1.]
# print only tensor v1 in checkpoint file
chkp.print_tensors_in_checkpoint_file("./tmp/model.ckpt", tensor_name='v1', all_tensors=False)
print("------")
# tensor_name: v1
# tensor_name: v1
# [ 1. 1. 1.]
# print only tensor v2 in checkpoint file
chkp.print_tensors_in_checkpoint_file("./tmp/model.ckpt", tensor_name='v2', all_tensors=False)
# tensor_name: v2
# [-1. -1. -1. -1. -1.]
分享到:
相关推荐
今天和大家分享一下用TensorFlow的saver存取训练好的模型那点事。 1. 用saver存取变量; 2. 用saver存取指定变量。 用saver存取变量。 话不多说,先上代码 # coding=utf-8 import os import tensorflow as tf ...
在TensorFlow中,Saver模块是用来保存和恢复模型的重要工具,尤其在模型训练过程中需要中断或者继续训练,或者需要在不同设备间迁移模型时显得尤为关键。本文将深入讲解如何使用TensorFlow Saver来保存和恢复指定的...
`TensorFlow Saver` 是一个内置模块,专门用于保存和加载模型的参数。本文将深入探讨如何使用`TensorFlow Saver`来处理`.ckpt`(checkpoint)文件。 首先,导入必要的库,包括`tensorflow`和`numpy`: ```python ...
Tensorflow针对这一需求提供了Saver类。 Saver类提供了向checkpoints文件保存和从checkpoints文件中恢复变量的相关方法。Checkpoints文件是一个二进制文件,它把变量名映射到对应的tensor值 。 只要提供一个计数器...
Tensorflow 19 Saver保存读取.mp4
Tensorflow_19_Saver_保存读取_(神经网络_教学教程tutorial)
在TensorFlow中,基本的保存和载入模型的方法通常是使用tf.train.Saver类。这个类提供了两个主要方法:saver.save()用于保存模型,saver.restore()用于载入模型。 在保存模型时,首先需要创建一个Saver对象,然后在...
TensorFlow使用Saver类来保存和恢复变量。下面详细介绍TensorFlow中使用Saver类保存和恢复变量的方法。 首先,我们来看如何建立TensorFlow的变量并保存。在TensorFlow中,任何需要持久化的变量都应当被定义为...
下面将详细讲解TensorFlow训练模型的关键概念、流程及`Saver`的作用。 1. **TensorFlow框架**:TensorFlow是由Google开发的一款开源库,主要用于构建和执行计算图,它支持数据流图的创建,广泛应用于机器学习和深度...
TensorFlow 提供了 `Saver` 类来方便地实现这一功能。本篇将详细介绍如何利用 `Saver` 在 TensorFlow 中保存和提取模型参数。 首先,让我们了解 `Saver` 的基本用法。在训练模型时,我们通常会在每个训练周期结束后...
模型保存与恢复**: 使用`tf.train.Saver()`可以保存和恢复模型的参数,以便于模型的持续训练或部署。 **7. Keras API**: TensorFlow还集成了高级API Keras,它简化了模型构建、编译和训练的流程,适合快速原型设计...
- **Saver**:TensorFlow提供Saver类来保存和恢复模型,这对于模型的持久化和继续训练至关重要。 **6. 模型部署与 Serving** - **TensorFlow Serving**:这是一个用于生产环境的模型部署工具,它允许灵活地更新和...
2. **模型保存与加载**:TensorFlow提供 saver API 来保存和恢复模型的参数。 五、分布式训练 1. **多GPU训练**:TensorFlow支持在多GPU上并行计算,提高训练速度。 2. **分布式集群**:通过参数服务器架构,可以...
- **模型保存与恢复**:了解`tf.train.Saver`和`tf.saved_model`模块,掌握如何保存和恢复模型权重,以便于模型的迁移和继续训练。 - **分布式训练**:研究`tf.distribute` API,理解如何在多GPU、多机器上进行...