Enterprise TensorFlow 2 – Saving a trained model

Part 2 in the series about Java / TensorFlow Interoperability, discussing how to save a model so it can be reused in a different environment.

Note that the discussion below also applies to persisting a trained model for environments other than Java, for example TensorFlow Serving.

Reminder: What is under the hood?

Before we look at how to save stuff in TensorFlow and load it again in Java, let’s first give a quick refresher on what TensorFlow is made of, so we know what actually needs saving and loading. If you feel you have a good understanding of how TensorFlow is organized under the hood, feel free to skip this section.

When we think of TensorFlow we often think about the Python source we write, but that code is actually just a tool to build a graph for the TensorFlow engine. A graph is a network of nodes and connections between those nodes. In TensorFlow this graph defines the algorithm we are using for training and inference, it defines how tensors flow from one operation (node) to another to achieve the desired effect. Hence the name TensorFlow.

So obviously we want to save the structure of the graph to run our inference (unless we re-implement the algorithm using means other than the TensorFlow engine).

Certain nodes in our graph contain state that has been learned, e.g. the weights of a neural network. We need to restore that state in our graph to get the correct results; otherwise all the training would have been useless. We always need to store this information, regardless of whether we use the TensorFlow engine for inference or a different framework or a hand-rolled solution instead.
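
The split between structure and learned state can be sketched with a toy dataflow graph in plain Python. This is not how TensorFlow represents graphs internally; the dictionary layout and the `run` helper are hypothetical and only show that the same graph produces useless output until the learned variable values are restored.

```python
# Toy dataflow graph: each node has an op type and a list of input node names.
# Hypothetical format, NOT TensorFlow's GraphDef; it only separates
# the structure (GRAPH) from the learned state (a dict of variable values).
GRAPH = {
    "x": {"op": "placeholder", "inputs": []},
    "w": {"op": "variable", "inputs": []},
    "b": {"op": "variable", "inputs": []},
    "wx": {"op": "mul", "inputs": ["w", "x"]},
    "y": {"op": "add", "inputs": ["wx", "b"]},
}


def run(graph, variables, feeds, output):
    """Evaluate the node `output`, recursively evaluating its inputs first."""
    cache = {}

    def evaluate(name):
        if name in cache:
            return cache[name]
        node = graph[name]
        op = node["op"]
        if op == "placeholder":
            value = feeds[name]          # input data supplied at run time
        elif op == "variable":
            value = variables[name]      # learned state, stored outside the graph
        else:
            args = [evaluate(i) for i in node["inputs"]]
            if op == "mul":
                value = args[0] * args[1]
            elif op == "add":
                value = args[0] + args[1]
            else:
                raise ValueError(f"unknown op {op!r}")
        cache[name] = value
        return value

    return evaluate(output)


trained = {"w": 2.0, "b": 0.5}      # what training produced
untrained = {"w": 0.0, "b": 0.0}    # same graph, state not restored
print(run(GRAPH, trained, {"x": 3.0}, "y"))    # 6.5
print(run(GRAPH, untrained, {"x": 3.0}, "y"))  # 0.0
```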

Depending on the algorithm we use, there are certain hyperparameters that also need to be taken into account. Some hyperparameters, like the size and number of layers of a neural network, are implicitly part of our graph, so we need not worry about them. Others are stored in variables, just like learned parameters, and are therefore saved along with the training results. Other parameters, however (e.g. thresholds for classification), might have to be set explicitly before running the graph, or are part of pre- or postprocessing. Those need to be transferred manually. This is generally straightforward, as a simple constant in your Java code will usually do. Just remember to go over all your hyperparameters and check that each of them is taken into account in your inference code.
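
As an illustration of a parameter that lives outside the graph, here is a minimal sketch assuming a model that outputs probabilities and a decision threshold tuned during validation. The constant name and its value are hypothetical; in the Java inference code the same value would simply become a constant there.

```python
# Hypothetical threshold found during validation; it is NOT stored in the
# graph or the checkpoint, so it has to be carried over to the inference code by hand.
CLASSIFICATION_THRESHOLD = 0.7


def postprocess(probabilities, threshold=CLASSIFICATION_THRESHOLD):
    """Turn raw model scores into class labels (1 = positive, 0 = negative)."""
    return [1 if p >= threshold else 0 for p in probabilities]


print(postprocess([0.2, 0.7, 0.95]))  # [0, 1, 1]
```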

Sometimes a TensorFlow graph might contain additional assets, like dictionaries for natural language applications. For the sake of simplicity, we will ignore those in the following discussion.

We also need to make sure we know what operations in our graph to call for inference, so we need to know what our input and output nodes are. So, to sum it up, the state we want to persist contains:

  • Our graph structure
  • The learned parameters of our model (variables of the graph)
  • Possibly hyperparameters (for the model or for pre- / postprocessing)
  • Sometimes additional assets
  • The information about input / output nodes (like the signature of a function)

Entities that make up our saved state

The TensorFlow API to save and load state has evolved quite a bit, and different solutions can be found on the TensorFlow site, various blogs and StackOverflow. We will focus on the SavedModel, as this is the approach we need to conveniently load our model with the TensorFlow Java API. Other approaches for saving and loading use parts of the entities that make up the SavedModel, so familiarizing yourself with it will also help you understand other APIs and tutorials. If you google the topic, or even if you start with the official TensorFlow docs, the first approach for loading and saving you will encounter is the Saver class. Internally, SavedModel uses a Saver for various aspects of its work, so the concepts are similar. However, SavedModel seems to be the focus for future versions of TensorFlow, so I suggest you only familiarize yourself with the Saver if you really need it. You will not be able to load your data in Java if you just use a Saver!

So what is a SavedModel? A SavedModel is the root of the TensorFlow persistence hierarchy and contains everything that needs saving. A SavedModel itself is not very complex: it is more or less just a map from “tags” (arbitrary string identifiers) to MetaGraphs, plus a number of “Checkpoints” (this is what TensorFlow calls the saved state of our variables).

The MetaGraph is the most important entity in our saved state. A MetaGraph is our familiar graph together with other meta information, like versioning information, assets and collections. The MetaGraph also contains the information about what nodes make up the input and which ones make up the output portion of our graph. TensorFlow calls this a MetaGraph’s signature (similar to method signatures in programming languages). A MetaGraph does not contain variable state!

So why can a SavedModel contain multiple MetaGraphs? The idea is that, depending on what we do with the complete graph we created, we do not need all of its parts. For example, when we just want to run inference, we do not need the parts of the graph necessary for training. So SavedModel allows us to take multiple snapshots of subgraphs of our total graph, each for a different use case.
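
The idea of keeping only the part of the graph a task needs can be sketched as a backwards walk from the requested outputs over each node’s inputs. The node names below are made up, and this is a toy illustration rather than TensorFlow’s own pruning code: the inference subgraph for `prediction` never touches `labels`, `loss`, `gradients` or `train_op`.

```python
from collections import deque

# Toy graph: node -> list of its input nodes (hypothetical names).
EDGES = {
    "x": [],
    "w": [],
    "logits": ["x", "w"],
    "prediction": ["logits"],
    "labels": [],
    "loss": ["logits", "labels"],
    "gradients": ["loss", "w"],
    "train_op": ["gradients", "w"],
}


def required_subgraph(edges, outputs):
    """Return every node needed to compute `outputs` by walking inputs backwards."""
    needed, queue = set(), deque(outputs)
    while queue:
        node = queue.popleft()
        if node not in needed:
            needed.add(node)
            queue.extend(edges[node])
    return needed


print(sorted(required_subgraph(EDGES, ["prediction"])))
# ['logits', 'prediction', 'w', 'x']
```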

A MetaGraph would be useless without the state of its variables (“Checkpoints”), so we need to make sure they are saved as well. Checkpoints are not stored with the MetaGraph because multiple MetaGraphs might contain the same variable node if the individual subgraphs overlap. If each MetaGraph contained the variable state as well and we saved multiple partly overlapping MetaGraphs, we would duplicate the variable content in our files. Depending on the size of the model, this might mean hundreds of megabytes (or even several gigabytes), so each variable’s state is saved only once, outside the MetaGraphs, even if it is referenced by more than one MetaGraph.
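
A toy sketch of why the shared checkpoint avoids duplication, assuming two hypothetical MetaGraphs (tagged `serve` and `train`) that reference overlapping variables: between them they reference five variables, but each distinct variable ends up in the checkpoint exactly once.

```python
# Two hypothetical MetaGraphs, each listing the variables it references.
meta_graphs = {
    "serve": {"variables": ["dense/kernel", "dense/bias"]},
    "train": {"variables": ["dense/kernel", "dense/bias", "global_step"]},
}
# Learned values (small stand-ins for potentially huge tensors).
variable_values = {"dense/kernel": [0.1, 0.2], "dense/bias": [0.0], "global_step": 1000}


def build_checkpoint(meta_graphs, values):
    """Store each referenced variable once, however many MetaGraphs use it."""
    referenced = set()
    for meta_graph in meta_graphs.values():
        referenced.update(meta_graph["variables"])
    return {name: values[name] for name in sorted(referenced)}


checkpoint = build_checkpoint(meta_graphs, variable_values)
print(list(checkpoint))  # ['dense/bias', 'dense/kernel', 'global_step']
```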

The following are the important API elements / concepts you need to know to save your model successfully:

  • SavedModel: the root of our hierarchy, contains multiple MetaGraphs, each with an individual name (tag).
  • MetaGraph: A (Sub-)Graph for one task, contains the graph structure, assets, and signature definitions for in-/output
  • Checkpoint: A snapshot of the variable state for all MetaGraphs in a SavedModel
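
The relationships in the list above can be summarized with a few Python dataclasses. This is a simplified model, not the actual Protobuf schema: in the real format each MetaGraph carries a set of tags (which is why the sketch keys MetaGraphs by a frozenset), and `serve` is the tag conventionally used for inference graphs. The class and method names are hypothetical.

```python
from dataclasses import dataclass, field


@dataclass
class MetaGraph:
    graph: dict               # graph structure (toy stand-in for a GraphDef)
    signature_inputs: dict    # signature key -> input node name
    signature_outputs: dict   # signature key -> output node name
    assets: list = field(default_factory=list)
    # note: no variable values here, those live in the shared checkpoint


@dataclass
class SavedModel:
    meta_graphs: dict         # frozenset of tags -> MetaGraph
    checkpoint: dict          # variable name -> value, shared by all MetaGraphs

    def load(self, tags):
        """Select the MetaGraph saved under exactly this set of tags."""
        key = frozenset(tags)
        if key not in self.meta_graphs:
            raise KeyError(f"no MetaGraph with tags {sorted(key)}")
        return self.meta_graphs[key], self.checkpoint


serving = MetaGraph(
    graph={"y": ["x", "w"]},
    signature_inputs={"x": "x"},
    signature_outputs={"y": "y"},
)
model = SavedModel(meta_graphs={frozenset({"serve"}): serving}, checkpoint={"w": 2.0})
meta_graph, variables = model.load(["serve"])
print(meta_graph.signature_outputs)  # {'y': 'y'}
```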

File formats used by TensorFlow

Most entities that can be saved and loaded in TensorFlow are defined as Protobuf records. Protobuf is a language independent serialization format, similar to JSON or XML. The big difference is that Protobuf in its “native state” is a densely packed binary format, not a text format. During development you can use a text format variant of Protobuf that looks very similar to JSON and is easier to debug, as it can be inspected with any text editor. Google uses Protobuf for communication in most of its distributed systems.
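
To give a feeling for how densely the binary format is packed, here is a small sketch of Protobuf’s base-128 varint encoding, which is used for integer fields. A message whose field number 1 holds the integer 150 (the classic example from the Protobuf encoding documentation) takes just three bytes: one key byte combining field number and wire type, and two bytes for the value. The helper names are hypothetical and only varint fields are covered; real code should use the generated classes.

```python
def encode_varint(value):
    """Protobuf base-128 varint: 7 bits per byte, low-order group first,
    high bit set on every byte except the last."""
    if value < 0:
        raise ValueError("this sketch only handles non-negative integers")
    out = bytearray()
    while True:
        byte = value & 0x7F
        value >>= 7
        if value:
            out.append(byte | 0x80)  # more bytes follow
        else:
            out.append(byte)         # last byte
            return bytes(out)


def encode_varint_field(field_number, value):
    """Encode one varint field: key = (field_number << 3) | wire type 0."""
    return encode_varint((field_number << 3) | 0) + encode_varint(value)


print(encode_varint_field(1, 150).hex())  # 089601
```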

Similar to a DTD or schema in XML, Protobuf entities are defined in external text files with the extension “.proto” (e.g. the definition of a whole graph). The “.proto” files are used by precompilers to build classes and (de-)serializers for various languages. For most languages there are well maintained protobuf precompilers, so loading and saving Protobuf data is generally pretty easy.

Serialized Protobuf records written to disk normally use the file extension “.pb” when they are in binary format or “.pbtxt” when they are in text format. To keep things interesting, TensorFlow sometimes uses these extensions, and sometimes it doesn’t.

You can tell which objects in the TensorFlow Python API are just autogenerated Protobuf Records by the suffix “Def”. So a tf.Graph is the actual graph you work with, a tf.GraphDef is a Protobuf record containing the information to restore a tf.Graph. Not every “Def” has a corresponding Python class, a MetaGraphDef has no corresponding MetaGraph class but is instead a collection of data containing a graph and additional information to run the model.

On disk, a SavedModel does not correspond to a single file but to a folder containing various files and subfolders with the saved state. The following is a list of files you will find in a SavedModel folder (as this API is still in flux, it is good to check the official docs from time to time):

  • assets/: contains auxiliary graph assets
  • assets.extra/: here you can put custom files that are not part of the standard format
  • variables/: contains your variable state (Checkpoints)
  • saved_model.pb: binary Protobuf containing the SavedModel, the MetaGraphs, the Graphs, the signatures, etc.
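
If you want a quick sanity check before pointing your Java code at an export folder, a sketch like the following can help. It only checks that the files listed above are present (assuming the model has variables; `saved_model.pbtxt` appears instead of `saved_model.pb` when exporting as text). The function name is hypothetical and it does not validate file contents.

```python
import tempfile
from pathlib import Path


def saved_model_problems(export_dir):
    """Rough layout check of a SavedModel folder; returns a list of problems.

    Only checks that the expected files exist, not that they are valid.
    """
    root = Path(export_dir)
    problems = []
    # saved_model.pbtxt is written instead of saved_model.pb when exporting as text
    if not ((root / "saved_model.pb").is_file() or (root / "saved_model.pbtxt").is_file()):
        problems.append("missing saved_model.pb / saved_model.pbtxt")
    variables = root / "variables"
    if not variables.is_dir():
        problems.append("missing variables/ folder")
    elif not any(variables.glob("*.index")):
        problems.append("variables/ contains no *.index file")
    return problems


# Usage: build a minimal fake layout in a temporary folder.
with tempfile.TemporaryDirectory() as export_dir:
    print(saved_model_problems(export_dir))
    # ['missing saved_model.pb / saved_model.pbtxt', 'missing variables/ folder']
    (Path(export_dir) / "saved_model.pb").write_bytes(b"")
    (Path(export_dir) / "variables").mkdir()
    (Path(export_dir) / "variables" / "variables.index").write_bytes(b"")
    print(saved_model_problems(export_dir))  # []
```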

The variables folder contains two types of files:

  • *.data: Content of variables. The filenames are numbered (e.g. variables.data-00000-of-00001) and there may be more than one file per checkpoint. The reason for this is that for very large neural networks the variable data can become huge, so TensorFlow splits it up when saving if necessary. This is a binary format, probably Protobuf as well. It is written by a node in the TensorFlow graph, so one has to dig deep to find more information about the actual format.
  • *.index: Used to map variable names to the variable data in the *.data files. This is necessary because it might not be clear which of the split-up files holds the data for a given variable. As with the *.data format, this is an unknown binary format.
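
TensorFlow typically names these files `<prefix>.index` and `<prefix>.data-<shard>-of-<count>`, for example `variables.index` plus `variables.data-00000-of-00001`. The sketch below (the helper name is hypothetical) uses that pattern to report missing shards or index files, for example after an incomplete copy between machines.

```python
import re

# Data shards are named <prefix>.data-<shard index>-of-<shard count> (5 digits each).
SHARD_RE = re.compile(r"^(?P<prefix>.+)\.data-(?P<index>\d{5})-of-(?P<total>\d{5})$")


def missing_checkpoint_files(filenames):
    """Report data shards and .index files that should exist but are absent."""
    names = set(filenames)
    shards = {}  # (prefix, total) -> set of shard indices found
    for name in names:
        match = SHARD_RE.match(name)
        if match:
            key = (match.group("prefix"), int(match.group("total")))
            shards.setdefault(key, set()).add(int(match.group("index")))
    missing = []
    for (prefix, total), found in sorted(shards.items()):
        missing += [f"{prefix}.data-{i:05d}-of-{total:05d}" for i in range(total) if i not in found]
        if f"{prefix}.index" not in names:
            missing.append(f"{prefix}.index")
    return missing


print(missing_checkpoint_files(["variables.index", "variables.data-00000-of-00002"]))
# ['variables.data-00001-of-00002']
```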

During training, various APIs (Estimator, MonitoredTrainingSession and others) save parts of a SavedModel. In this case there are no nested folders, some additional files appear, and some of the files listed above are missing:

  • checkpoint: a plain text file with the name of the latest Checkpoint. TensorFlow uses it to identify the checkpoint to load when resuming training.
  • *.meta: This is the Protobuf of a MetaGraphDef. This is part of what is otherwise stored in a saved_model.pb.
  • graph.pbtxt: Depending on how you invoke saving your model, a complete protobuf export of the graph is written to this file. This is not a SavedModel and not a MetaGraph, but just a Graph.
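
The checkpoint file is a small text-format Protobuf; the sample string below shows what such a file typically looks like (the step numbers are made up). In TensorFlow itself you would call `tf.train.latest_checkpoint(checkpoint_dir)`; the hand-written parser here is only a sketch to show what the file contains, and it ignores every field except `model_checkpoint_path`.

```python
# Sample content of a `checkpoint` file (illustrative step numbers).
SAMPLE_CHECKPOINT_FILE = '''model_checkpoint_path: "model.ckpt-2000"
all_model_checkpoint_paths: "model.ckpt-1000"
all_model_checkpoint_paths: "model.ckpt-2000"
'''


def latest_checkpoint_name(checkpoint_file_text):
    """Return the model_checkpoint_path entry, or None if there is none."""
    for line in checkpoint_file_text.splitlines():
        key, sep, value = line.partition(":")
        if sep and key.strip() == "model_checkpoint_path":
            return value.strip().strip('"')
    return None


print(latest_checkpoint_name(SAMPLE_CHECKPOINT_FILE))  # model.ckpt-2000
```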

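Note that TensorFlow provides a command line utility, saved_model_cli, to inspect SavedModels (for example, saved_model_cli show --dir <export dir> --all lists the tag sets and signatures of an export). It is highly useful, and I recommend familiarizing yourself with it if you regularly work with SavedModels.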

Saving data during training and for production

There are various ways to create the above files from TensorFlow, some of them apparently deprecated. In this section I will focus on the approach I consider current best practice, which also gives us the results we need to use the Java TensorFlow wrapper library.

The whole subject of saving and loading in TensorFlow is an impressive example of the Big Ball o’ Mud Antipattern. It seems that again and again parts have been grafted on here and there because previous methods lacked some crucial functionality. However, this was done in different ways (global functions, object orientation) in different places in the code. This turns one of the simplest and most important tasks into a hotly discussed subject, as the vast number of blog posts and StackOverflow questions shows.

We will achieve the desired goal in the following way:

  1. We wrap our model code in a tf.estimator.Estimator (or use one of the ready-to-use Estimators in tf.estimator)
  2. The Estimator will take care of automagically saving the progress during training using MetaGraphs and Checkpoints (not SavedModels)
  3. When we are happy with the training result, we will export a SavedModel for production / embedding in our Java environment

Note that the first result you will find when you search for saving stuff in TensorFlow is the tf.train.Saver class. This class is used by most of the various saving methods in TensorFlow for some parts of the state. However, I discourage its direct use, as everything in the TensorFlow docs seems to indicate the way to go for the future is the SavedModel. In our case we cannot use the Saver directly anyway, as the TensorFlow Java API needs a complete SavedModel to load, not just the parts created by the Saver. (SavedModel uses the Saver itself for checkpoints internally)

Using the tf.estimator.Estimator API

What is an Estimator?

The Estimator is the latest API suggested by TensorFlow for organizing your learning process. I say “suggest” because you can write your model completely without it, and indeed most examples you will find do not use Estimators (yet). The Estimator takes care of the following tasks:

  • Reading input using an input_fn
  • Defining the model
  • Training
  • Validating
  • Testing
  • Keeping track of state during training
  • Exporting

In a nutshell, the Estimator is a complete definition of a TensorFlow model, with all bells and whistles.

However, according to the docs, Estimators are the recommended way of doing things in the future, and they provide a couple of benefits for us:

  • by adhering to a common standard and best practice, programs become easier to understand and compare
  • Estimators take care of loading and saving during training for us
  • Estimators allow (relatively) easy exporting to SavedModels (this is our main reason for using them in these examples)
  • Often, you can just use a prepackaged Estimator and do not need to write your own Estimator at all
  • Estimators can be quickly exchanged for one another, allowing for quick comparison of models

Also, unlike many other aspects of the “scaffolding” part of TensorFlow, Estimators seem well designed (except for minor kinks here and there).

Using an Estimator

The following steps are necessary to use an Estimator:

  1. Define the input features for your Estimator using the tf.feature_column API.
  2. Define an input_fn to feed data to your Estimator for training and validation
  3. Use an instance of tf.estimator.RunConfig to control the training process (target folders, logging, etc.)
  4. Train your model using the Estimator.train method
  5. Validate your model on your validation set using the Estimator.evaluate method
  6. Export your model as a SavedModel using Estimator.export_savedmodel

As the focus of this post is saving models for use in a Java environment, we will focus only on the last part. A complete example using Estimators, including defining data input, preprocessing, training and export to a SavedModel can be found here. Details about the other steps will be covered in a later post dealing with the intricacies of the Estimator. If you want to use a custom Estimator, you can easily write one yourself. It works like a prepackaged Estimator; you just need to define your own model_fn that builds the TensorFlow graph for your model. Building custom Estimators will also be covered in a separate post, as exporting a SavedModel works exactly the same for prepackaged and custom Estimators.

Saving using an Estimator

Saving a SavedModel from a trained Estimator requires just one method call; we only need to define some additional arguments for this call to succeed. The key to a successful export is specifying a correct serving_input_fn. This is similar to the input_fn needed for training and validation: it specifies how to provide data to the model when running inference. However, we need not write this function ourselves. We just specify a map of input parameters and use another API call to have TensorFlow build the function for us.

Luckily, building this map is extremely simple: we just define an empty map and add placeholders with their respective names as keys. So if your model needs three float input tensors named foo, bar and baz, you define the serving_input_fn like so:

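import tensorflow as tf  # the examples in this post use the TensorFlow 1.x API

# one placeholder per input feature, keyed by the feature name the model expects
feature_spec = {}
feature_spec['foo'] = tf.placeholder(tf.float32, shape=[None], name='foo')
feature_spec['bar'] = tf.placeholder(tf.float32, shape=[None], name='bar')
feature_spec['baz'] = tf.placeholder(tf.float32, shape=[None], name='baz')

# let TensorFlow build the serving input function from the placeholder map
serving_input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(feature_spec)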

The shape of course depends on your input. In this case we only have a single float value per case, so [None] is enough (None specifies that we have an unknown number of input cases, so we can run an arbitrary number of cases in one batch). For a 256×256 grayscale image, the batch dimension still comes first, so the shape would be [None, 256, 256], or [None, 256, 256, 1] if your model expects an explicit channel dimension.
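
A toy version of the shape rule, where None matches any size and the batch dimension comes first, might look like this. The helper is hypothetical and not TensorFlow’s own shape checking; the image shape assumes the common [batch, height, width, channels] layout with one grayscale channel.

```python
def shape_compatible(declared, actual):
    """True if `actual` fits `declared`, where None in `declared` means 'any size'."""
    return len(declared) == len(actual) and all(
        d is None or d == a for d, a in zip(declared, actual)
    )


# Batch dimension first: any number of 256x256 single-channel images.
IMAGE_SHAPE = [None, 256, 256, 1]
print(shape_compatible(IMAGE_SHAPE, (32, 256, 256, 1)))  # True
print(shape_compatible(IMAGE_SHAPE, (32, 1, 256, 256)))  # False: batch-last layout
print(shape_compatible([None], (5,)))                    # True: 5 scalar cases
```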

Now that we have the serving_input_fn, we just need to decide where to save, and we are done:

# regressor is the trained Estimator (prepackaged or custom)
regressor.export_savedmodel(
    # under this dir a timestamped subdir with the SavedModel is created
    export_dir_base="saved_models",
    # this defines the part of the graph
    # that feeds data during inference
    serving_input_receiver_fn=serving_input_fn,
    # with this map you can populate the assets.extra
    # directory with custom files - this can be used to add
    # additional data to the saved model, but you
    # have to load it manually
    assets_extra=None,
    # for debugging, you can export some protobuf
    # files as text
    as_text=False,
    # if you want a different checkpoint than the
    # last, you can pass a different checkpoint file
    # here, None always uses the last checkpoint
    checkpoint_path=None
)

That’s it. The export creates a subdirectory under saved_models named after a timestamp (seconds since the epoch). That is the directory we need to load from in our Java code.
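
Because every export creates a new timestamped folder, deployment scripts often need to pick the newest one. A minimal sketch, assuming all exports live directly under the export base folder and that only the timestamped folders have purely numeric names (anything else, such as temporary folders, is ignored); the function name is hypothetical.

```python
import tempfile
from pathlib import Path


def latest_export(export_dir_base):
    """Return the newest export folder, i.e. the subfolder with the largest numeric name."""
    candidates = [
        p for p in Path(export_dir_base).iterdir() if p.is_dir() and p.name.isdigit()
    ]
    if not candidates:
        raise FileNotFoundError(f"no timestamped exports in {export_dir_base}")
    # compare numerically, so that e.g. "10" counts as newer than "9"
    return max(candidates, key=lambda p: int(p.name))


with tempfile.TemporaryDirectory() as base:
    for name in ("1527000000", "1528000000", "temp-1529000000"):
        (Path(base) / name).mkdir()
    print(latest_export(base).name)  # 1528000000
```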

Alternatives to the Estimator / SavedModel approach

As mentioned earlier, there are far too many ways to get state in and out of TensorFlow, so we focused on the most current and future-proof ones. Nevertheless, there are two other approaches worth knowing for certain special cases.

Manual Saving of variables

The SavedModel approach is meant for cases where we want to load our model using TensorFlow again, be it a TensorFlow wrapper for a different language than Python or our standard Python TensorFlow code. As mentioned in the previous post however, there are cases where you use the trained state in a different framework or in hand-rolled code. In this case, we only care about the variable state. As the checkpoint format is rather opaque, we need more control over the format we save our variable data in.

Solving this problem amounts to just evaluating the variable nodes we are interested in, storing the results in numpy arrays, and then manually writing the data in a format of our choice (e.g. CSV). Here is a working example.
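
A minimal sketch of the writing and reading side, assuming the variable values have already been fetched (for example with session.run(...)) and converted from numpy arrays to nested lists with .tolist(). The CSV layout with one row per matrix cell and the helper names are hypothetical choices; vectors such as biases are stored as 1×n matrices here.

```python
import csv
import io


def write_variables_csv(variables, fh):
    """Write {name: 2-D list of numbers} as CSV rows: name, row, col, value."""
    writer = csv.writer(fh)
    writer.writerow(["name", "row", "col", "value"])
    for name, matrix in sorted(variables.items()):
        for i, row in enumerate(matrix):
            for j, value in enumerate(row):
                writer.writerow([name, i, j, float(value)])


def read_variables_csv(fh):
    """Inverse of write_variables_csv; assumes dense, rectangular matrices."""
    cells = {}
    for record in csv.DictReader(fh):
        position = (int(record["row"]), int(record["col"]))
        cells.setdefault(record["name"], {})[position] = float(record["value"])
    result = {}
    for name, entries in cells.items():
        rows = max(r for r, _ in entries) + 1
        cols = max(c for _, c in entries) + 1
        result[name] = [[entries[(r, c)] for c in range(cols)] for r in range(rows)]
    return result


# Stand-in values; in practice these would come from session.run(...) and .tolist()
weights = {"dense/kernel": [[0.5, -1.25], [2.0, 0.0]], "dense/bias": [[0.1, 0.2]]}
buffer = io.StringIO()
write_variables_csv(weights, buffer)
buffer.seek(0)
print(read_variables_csv(buffer) == weights)  # True
```

Any other language or hand-rolled inference code can then read the file with its own CSV parser.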

Given how complicated the “right” way to save stuff is, this is painfully simple and straightforward to understand.

Graph Freezing

Graph freezing is a technique where the variables in a graph are replaced by constants containing the current variable state. As constant values are saved as part of the graph structure (unlike variable values), this way we can export a complete model in a single file by just saving the graph and ignoring the variables. As the API for loading TensorFlow models in Java loads the variable state automatically anyway, we gain nothing from this method. Its major benefit seems to be that you end up with only one file, which in my opinion is not worth the overhead and added project complexity; you can easily achieve the same by just zipping the exported files, for example.
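
Using a toy dictionary representation of a graph (hypothetical, not TensorFlow’s format), freezing boils down to replacing each variable node with a constant node that carries the checkpoint value. This is only a conceptual sketch with made-up names; TensorFlow 1.x ships a real utility for this, tf.graph_util.convert_variables_to_constants.

```python
import json


def freeze(graph, checkpoint):
    """Copy a toy graph, replacing every variable node by a constant holding its value."""
    frozen = {}
    for name, node in graph.items():
        if node["op"] == "variable":
            frozen[name] = {"op": "const", "inputs": [], "value": checkpoint[name]}
        else:
            frozen[name] = dict(node, inputs=list(node["inputs"]))  # copy, leave input untouched
    return frozen


GRAPH = {
    "x": {"op": "placeholder", "inputs": []},
    "w": {"op": "variable", "inputs": []},
    "y": {"op": "mul", "inputs": ["w", "x"]},
}
frozen = freeze(GRAPH, {"w": 2.0})
print(frozen["w"])  # {'op': 'const', 'inputs': [], 'value': 2.0}
# No variables left, so structure and values fit into one self-contained document.
print(json.dumps(frozen))
```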

Nevertheless, it is an interesting and elegant approach and might be beneficial in certain contexts, so I thought I might mention it here. If you are interested in how it is done, you can find a great tutorial here.

Summary

  • The loading / saving API for TensorFlow is regrettably rather complicated and confusing
  • The SavedModel approach is the way to go for future development and the approach we need for loading our model from Java
  • Recommended best practice: Wrap your model in an Estimator (or use a prepackaged estimator).
  • Let the Estimator take care of saving and loading during training automatically
  • When you are happy with your model, export your SavedModel from your Estimator
  • Do not use the Saver manually anymore, even though many tutorials and examples still do.