Enterprise TensorFlow - Saving a trained model
Part 2 in the series about Java / TensorFlow Interoperability, discussing how to save a model so it can be reused in a different environment.
Note that the discussion below also applies to persisting a trained model for environments other than Java, for example TensorFlow Serving.
Reminder: What is under the hood?
Before we look at how to save stuff in TensorFlow and load it again in Java, let’s first give a quick refresher of what TensorFlow is built of, so we know what actually needs saving and loading. If you feel you have a good understanding of how TensorFlow is organized under the hood, feel free to skip this section.
When we think of TensorFlow we often think about the Python source we write, but that code is actually just a tool to build a graph for the TensorFlow engine. A graph is a network of nodes and connections between those nodes. In TensorFlow this graph defines the algorithm we are using for training and inference, it defines how tensors flow from one operation (node) to another to achieve the desired effect. Hence the name TensorFlow.
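As a tiny reminder of what that means in code, building the graph and running it are two separate steps; a minimal sketch in the TensorFlow 1.x style used throughout this series:
import tensorflow as tf

# these calls only add nodes to the graph, nothing is computed yet
x = tf.placeholder(tf.float32, name='x')
w = tf.Variable(3.0, name='w')
y = tf.multiply(w, x, name='y')

# only when the graph is handed to the engine do tensors actually flow
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(y, feed_dict={x: 2.0}))  # prints 6.0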
So obviously we want to save the structure of the graph to run our inference (unless we re-implement the algorithm using other means than the TensorFlow engine).
Certain nodes in our graph contain state that has been learned, e.g. weights for neural networks. We need to restore that state in our graph to get the correct results, otherwise all the training would have been useless. This information we always need to store, regardless of whether we use the TensorFlow engine for inference or a different framework or a hand-rolled solution.
Depending on the algorithm we use, there are certain hyperparameters that also need to be taken into account. Some hyperparameters, like the size and number of layers of a neural network, are implicitly part of our graph, so we need not worry about them. Others are stored in variables, just like learned parameters, and are therefore saved along with the learning results. Other parameterizations, however (e.g. thresholds for classification), might have to be set explicitly before running the graph, or are part of pre- or postprocessing. Those need to be transferred manually. This is generally straightforward, as a simple constant in your Java code will usually do. Just remember to go over all your hyperparameters and check that they are all considered in your inference code.
Sometimes a TensorFlow graph might contain additional assets, like dictionaries for natural language applications. For the sake of simplicity, we will ignore those in the following discussion.
We also need to make sure we know what operations in our graph to call for inference, so we need to know what our input and output nodes are. So, to sum it up, the state we want to persist contains:
- Our graph structure
- The learned parameters of our model (variables of the graph)
- Possibly hyperparameters (for the model or for pre-/postprocessing)
- Sometimes additional assets
- The information about input / output nodes (like the signature of a function)
Entities that make up our saved state
The TensorFlow API to save and load state has evolved quite a bit, and different solutions can be found on the TensorFlow site, various blogs and StackOverflow. We will focus on the SavedModel, as this is the approach we need to conveniently load our model with the TensorFlow Java API. Other approaches for saving and loading use parts of the entities that make up the SavedModel, so familiarizing yourself with it will also help you understand other APIs and tutorials. If you google the topic, or even if you start with the official TensorFlow docs, the first approach for loading and saving you will find is the Saver class. Internally, SavedModel uses a Saver for various aspects of its work, so the concepts are similar; however, SavedModel seems to be the focus for future versions of TensorFlow, so I suggest you only familiarize yourself with the Saver if you really need it. You will not be able to load your data in Java if you just use a Saver!
So what is a SavedModel? A SavedModel is the root of the TensorFlow persistence hierarchy. It contains everything that needs saving. A SavedModel itself is not very complex; it is more or less just a map from “tags” (arbitrary string identifiers) to MetaGraphs, plus a number of “Checkpoints” (this is what TensorFlow calls the state of our variables).
The MetaGraph is the most important entity in our saved state. A MetaGraph is our familiar graph together with other meta information, like versioning information, assets and collections. The MetaGraph also contains the information about which nodes make up the input and which ones make up the output portion of our graph. TensorFlow calls this a MetaGraph’s signature (similar to method signatures in programming languages). A MetaGraph does not contain variable state!
So why can a SavedModel contain multiple MetaGraphs? The idea here is that, depending on what we do with the complete graph we created, we do not need all aspects of it. For example, when we just want to run inference, we do not need the parts of the graph necessary for learning. So a SavedModel allows us to take multiple snapshots of subgraphs of our total graph, each for a different use case.
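To make this concrete, the lower-level SavedModelBuilder API saves several MetaGraphs sharing one set of variables roughly like this (just a sketch - later we will let the Estimator export do this for us):
import tensorflow as tf

w = tf.Variable(1.0, name='w')  # stands in for the trained model variables

builder = tf.saved_model.builder.SavedModelBuilder('export_dir')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # first MetaGraph, stored together with the variable Checkpoint
    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.TRAINING])

    # a second MetaGraph under a different tag, referencing the same
    # variables - nothing gets duplicated on disk
    builder.add_meta_graph([tf.saved_model.tag_constants.SERVING])

builder.save()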
A MetaGraph would be useless without the state of its variables (“Checkpoints”), so we need to make sure they are saved as well. The reason why Checkpoints are not stored with the MetaGraph is that multiple MetaGraphs might contain the same variable node if the individual subgraphs overlap. If the MetaGraph contained the variable state as well, and we saved multiple partly overlapping MetaGraphs, we would duplicate the variable content in our files. Depending on the size of the model, this might mean hundreds of megabytes (or even several gigabytes), so each variable’s state is only saved once, outside the MetaGraphs, even if it is referenced by more than one MetaGraph.
The following are the important API elements / concepts you need to know to save your model successfully:
- SavedModel: the root of our hierarchy; contains multiple MetaGraphs, each with an individual name (tag)
- MetaGraph: a (sub-)graph for one task; contains the graph structure, assets, and the signature definitions for input/output
- Checkpoint: a snapshot of the variable state for all MetaGraphs in a SavedModel
File formats used by TensorFlow
Most entities that can be saved and loaded in TensorFlow are defined as Protobuf records. Protobuf is a language-independent serialization format, similar to JSON or XML. The big difference is that Protobuf in its “native state” is a densely packed binary format, not a text format. While developing, you can use a text format variant of Protobuf that looks very similar to JSON and is easier to debug, as it can be inspected with any text editor. Protobuf is used by Google for communication in most of their distributed systems.
Similar to a DTD or schema in XML, Protobuf entities are defined in external text files with the extension “.proto” (e.g. the definition of a whole graph). The “.proto” files are used by precompilers to build classes and (de-)serializers for various languages. For most languages there are well maintained protobuf precompilers, so loading and saving Protobuf data is generally pretty easy.
Serialized Protobuf records written to disk normally use the file extension “.pb” when they are in binary format or “.pbtxt” when they are in text format. To keep things interesting, TensorFlow sometimes uses these extensions, and sometimes it doesn’t.
You can tell which objects in the TensorFlow Python API are just autogenerated Protobuf records by the suffix “Def”. So a tf.Graph is the actual graph you work with, while a tf.GraphDef is a Protobuf record containing the information to restore a tf.Graph. Not every “Def” has a corresponding Python class: a MetaGraphDef has no corresponding MetaGraph class but is instead a collection of data containing a graph and additional information to run the model.
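As a quick illustration of the Graph/GraphDef distinction, here is a small sketch that serializes the default graph’s GraphDef to a binary “.pb” file and parses it back:
import tensorflow as tf

a = tf.constant(1.0, name='a')
b = tf.constant(2.0, name='b')
c = tf.add(a, b, name='c')

# tf.Graph is the in-memory object; its GraphDef is the Protobuf record
graph_def = tf.get_default_graph().as_graph_def()

# serialize the GraphDef to a binary ".pb" file
with open('my_graph.pb', 'wb') as f:
    f.write(graph_def.SerializeToString())

# parse it back into a fresh GraphDef
restored = tf.GraphDef()
with open('my_graph.pb', 'rb') as f:
    restored.ParseFromString(f.read())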
A SavedModel on disk does not correspond to a single file but to a folder containing various files and subfolders with the saved state. The following is a list of files you will find in a SavedModel folder (as this API is still in flux, it is good to check the official docs from time to time):
- assets/: contains auxiliary graph assets
- assets.extra/: here you can put custom files that are not part of the standard format
- variables/: contains your variable state (Checkpoints)
- saved_model.pb: binary Protobuf containing the SavedModel, the MetaGraphs, the Graphs, the signatures etc.
The variables folder contains two types of files:
- *.data: the content of the variables. The filenames are numbered, and there may be more than one file per Checkpoint. The reason for this is that for very large neural networks the variable data can become quite huge - so if necessary TensorFlow splits it up when saving. This is a binary format, probably Protobuf as well; it gets written by a node in the TensorFlow graph, so one has to dig deep to find more info about the actual format.
- *.index: used to map variable names to the variable data in the *.data files. This is necessary as it might not be clear in which of the split-up files a variable’s data resides. As with the *.data format, this is an unknown binary format.
During training, various APIs (Estimator, MonitoredTrainingSession and others) save parts of a SavedModel. In this case there are no nested folders; some additional files appear, and some other files are missing:
- checkpoint: a plain-text file with the name of the latest Checkpoint. Using this, TensorFlow identifies the checkpoint to load when resuming training (see the snippet after this list).
- *.meta: the Protobuf of a MetaGraphDef. This is part of what is otherwise stored in a saved_model.pb.
- graph.pbtxt: depending on how you invoke saving your model, a complete Protobuf export of the graph is written to this file. This is not a SavedModel and not a MetaGraph, but just a Graph.
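For example, you can resolve the Checkpoint that the plain-text checkpoint file currently points to yourself; a one-line sketch, assuming your training directory is named training_dir:
import tensorflow as tf

# reads the "checkpoint" file in training_dir and returns the path prefix
# of the most recent checkpoint, e.g. "training_dir/model.ckpt-1000"
latest = tf.train.latest_checkpoint('training_dir')
print(latest)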
Note that TensorFlow provides command line utilities (most notably saved_model_cli) to inspect SavedModels. They are highly useful; I recommend familiarizing yourself with them if you regularly work with SavedModels.
Saving data during training and for production
There are various ways to create the above files from TensorFlow, some of them seemingly deprecated. In this section I will focus on the approach that I think is current best practice and that allows us to get the results we need to use the Java TensorFlow wrapper library.
The whole subject of saving and loading in TensorFlow is an impressive example of the Big Ball of Mud antipattern. It seems that again and again parts have been grafted on here and there because previous methods lacked some crucial functionality, but this was done in different ways (global functions, object orientation) at different places in the code. This turns one of the simplest and most important tasks into a hotly discussed subject, as the vast number of blog posts and StackOverflow questions shows.
We will achieve the desired goal in the following way:
- We wrap our model code in a tf.estimator.Estimator (or use one of the ready-to-use Estimators in tf.estimator)
- The Estimator will take care of automagically saving the progress during training using MetaGraphs and Checkpoints (not SavedModels)
- When we are happy with the training result, we export a SavedModel for production / embedding in our Java environment
Note that the first result you will find when you search for saving stuff in TensorFlow is the tf.train.Saver class. This class is used by most of the various saving methods in TensorFlow for some parts of the state. However, I discourage its direct use, as everything in the TensorFlow docs seems to indicate that the way to go for the future is the SavedModel. In our case we cannot use the Saver directly anyway, as the TensorFlow Java API needs a complete SavedModel to load, not just the parts created by the Saver. (SavedModel itself uses the Saver internally for Checkpoints.)
Using the tf.estimator.Estimator API
What is an Estimator?
The Estimator is the latest API suggested by TensorFlow for organizing your learning process. I say “suggested” because you can write your model completely without it, and indeed most examples you will find do not use Estimators (yet). The Estimator takes care of the following tasks:
- Reading input using an input_fn
- Defining the model
- Training
- Validating
- Testing
- Keeping track of state during training
- Exporting
In a nutshell, the Estimator is a complete definition of a TensorFlow model, with all bells and whistles.
According to the docs, Estimators are the recommended way of doing things in the future, and they provide a couple of benefits for us:
- By adhering to a common standard and best practice, programs become easier to understand and to compare
- Estimators take care of loading and saving during training for us
- Estimators allow (relatively) easy exporting to SavedModels (this is our main reason for using them in these examples)
- Often, you can just use a prepackaged Estimator and do not need to write your own Estimator at all
- Estimators can be quickly exchanged for one another, allowing for quick comparison of models
Also, unlike many other aspects of the “scaffolding” part of TensorFlow, Estimators seem well designed (except for minor kinks here and there).
Using an Estimator
The following steps are necessary to use an Estimator (a rough sketch of how they fit together follows after the list):
- Define the input features for your Estimator using the tf.feature_column API
- Define an input_fn to feed data to your Estimator for training and validation
- Use an instance of tf.estimator.RunConfig to control the training process (target folders, logging, etc.)
- Train your model using the Estimator.train method
- Validate your model on your validation set using the Estimator.evaluate method
- Export your model as a SavedModel using Estimator.export_savedmodel
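Not the focus of this post, but for orientation, here is a rough sketch of how the first steps might fit together with a prepackaged Estimator (the feature names, data and hyperparameters are made up for illustration):
import numpy as np
import tensorflow as tf

# 1. input features (hypothetical names, matching the export example below)
feature_columns = [tf.feature_column.numeric_column('foo'),
                   tf.feature_column.numeric_column('bar'),
                   tf.feature_column.numeric_column('baz')]

# 2. an input_fn feeding features and labels during training (dummy data here)
train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'foo': np.random.rand(100), 'bar': np.random.rand(100), 'baz': np.random.rand(100)},
    y=np.random.rand(100),
    batch_size=16, num_epochs=None, shuffle=True)

# 3. control where Checkpoints and MetaGraphs are written during training
run_config = tf.estimator.RunConfig(model_dir='training_dir')

# a prepackaged Estimator
regressor = tf.estimator.DNNRegressor(feature_columns=feature_columns,
                                      hidden_units=[10, 10],
                                      config=run_config)

# 4. train (validation via regressor.evaluate omitted for brevity)
regressor.train(input_fn=train_input_fn, steps=1000)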
As the focus of this post is the saving of models for usage in a Java environment, we will focus only on the last part. A complete example using Estimators, including defining data input, preprocessing, training and export to a SavedModel, can be found here. Details about the other steps will be covered in a later post dealing with the intricacies of the Estimator. If you want to use a custom Estimator, you can easily write one yourself. It works like a prepackaged Estimator, you just need to define your own model_fn to define the TensorFlow graph for your model. Building custom Estimators will also be covered in a separate post, as exporting a SavedModel works exactly the same for prepackaged and custom Estimators.
Saving using an Estimator
Saving a SavedModel from a trained Estimator requires just one method call; we just need to define some additional arguments for this call to succeed. The key to a successful export is specifying a correct serving_input_fn. This is similar to the input_fn needed for training and validating: it specifies how to provide data to the model when running inference. However, we need not write the function manually ourselves, we just need to specify a map with the input parameters and use another API call to have TensorFlow build the function for us.
Luckily, building this map is extremely simple: we just define an empty map and add placeholders with their respective names as keys. So if your model needs three float input tensors named foo, bar and baz, you define the serving_input_fn like so:
feature_spec = {}
feature_spec['foo'] = tf.placeholder(tf.float32, shape=[None], name='foo')
feature_spec['bar'] = tf.placeholder(tf.float32, shape=[None], name='bar')
feature_spec['baz'] = tf.placeholder(tf.float32, shape=[None], name='baz')
serving_input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(feature_spec)
The shape of course depends on your input; in this case we only have single float values per case, so [None] is enough (None specifies that we have an unknown number of input cases - that way we can run an arbitrary number of cases in one batch). For a 256x256 grayscale image, the shape would be [None, 256, 256] - the batch dimension comes first here as well.
Now that we have the serving_input_fn, we just need to decide where to save and we are done:
regressor.export_savedmodel(
# under this dir a subdir with the SavedModel is created
export_dir_base = "saved_models",
# this will define the part of the graph
# that feeds data during inference
serving_input_receiver_fn = serving_input_fn,
# with this map you can populate a directory
# with custom data - this can be used to add
# additional data to the saved model, but you
# have to load it manually
assets_extra=None,
# for debugging, you can export some protobuf
# files as text
as_text=False,
# if you want a different checkpoint than the
# last, you can pass a different checkpoint file
    # here, None always uses the last checkpoint
checkpoint_path=None
)
That’s it. The export will have created a subdirectory under saved_models with a millisecond timestamp as its name. It is that directory we need to load from in our Java code.
Alternatives to the Estimator / SavedModel approach
As mentioned earlier, there are far too many ways to get state in and out of TensorFlow, so we focused on the most current and future-proof ones. Nevertheless, there are two other approaches worth knowing for certain special cases.
Manual Saving of variables
The SavedModel approach is meant for cases where we want to load our model using TensorFlow again, be it a TensorFlow wrapper for a different language than Python or our standard Python TensorFlow code. As mentioned in the previous post however, there are cases where you use the trained state in a different framework or in hand-rolled code. In this case, we only care about the variable state. As the checkpoint format is rather opaque, we need more control over the format we save our variable data in.
Solving this problem amounts to just evaluating the variable nodes we are interested in and storing the result in a numpy array, then writing the data in a format of our choice manually (e.g. CSV).
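A minimal sketch of this idea (assuming the model graph has already been built, a Checkpoint exists under training_dir, and the variable is named "weights" - all made-up specifics):
import numpy as np
import tensorflow as tf

with tf.Session() as sess:
    # restore the trained variable state from the latest Checkpoint
    tf.train.Saver().restore(sess, tf.train.latest_checkpoint('training_dir'))

    # evaluating the variable node yields a plain numpy array
    weights = sess.run(tf.get_default_graph().get_tensor_by_name('weights:0'))

    # write it out in any format we like, e.g. CSV
    np.savetxt('weights.csv', weights, delimiter=',')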
Given how complicated the “right” way to save stuff is, this is painfully simple and straightforward to understand.
Graph Freezing
Graph freezing is a technique where the variables in a graph are replaced by constants containing the current variable state. As the values of constants are saved with the graph structure (unlike the values of variables), this lets us export a complete model in a single file by just saving the graph and ignoring the variables. Since the API for loading TensorFlow models in Java loads the variable state just as easily, we gain no benefit from this method. The major benefit seems to be that you end up with only one file, which in my opinion is not worth the overhead and added project complexity - you can easily achieve the same by just zipping the exported files, for example.
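In TensorFlow 1.x the heavy lifting is done by a single utility function; a rough sketch (the output node name is a placeholder):
import tensorflow as tf

with tf.Session() as sess:
    # ... build or restore the trained graph here ...
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        ['output_node'])  # names of the nodes whose output we need

    # the frozen GraphDef now carries the former variable values as constants
    with open('frozen_model.pb', 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())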
Nevertheless, it is an interesting and elegant approach and might be beneficial in certain contexts, so I thought I might mention it here. If you are interested in how it is done, you can find a great tutorial here.
Summary
- The loading / saving API for TensorFlow is regrettably rather complicated and confusing
- The SavedModel approach is the way to go for future development and the approach we need for loading our model from Java
- Recommended best practice: Wrap your model in an Estimator (or use a prepackaged estimator).
- Let the Estimator take care of saving and loading during training automatically
- When you are happy with your model, export your SavedModel from your Estimator
- Do not use the Saver manually anymore, even though it is found in many tutorials and examples.