Enterprise TensorFlow 3 – Loading a SavedModel in Java

Part 3 in the series about Java / TensorFlow Interoperability, showing how to load a TensorFlow SavedModel in Java

In the previous post we saved a trained model as a SavedModel. This left us with a folder of binary files holding the graph structure, the learned parameters, optional assets and the signature (input and output tensors) for running inference on our model. To use this result in a JVM environment, we need to load the model and execute it. Luckily, loading a model is far easier than saving it, so this post will be a lot shorter. Unlike saving in Python, there is only one API for loading in Java, and it always works the same way.

As with the previous post, you can find a complete working example on GitHub.

Adding TensorFlow to a Java project

When discussing the pros and cons of using Java to run our trained models in a real-world product, we saw that TensorFlow is actually a C++ computing library with a Python frontend to build and run computing graphs. Using JNI, this library can be used in Java just as well. Luckily, Google already did the grunt work and produced a slim wrapper around the TensorFlow library that provides everything we need to run any TensorFlow graph, as long as it's saved as a SavedModel. The library is far from comfortable enough for training and testing though, as all of the high-level API found in Python is missing.

But as we are only interested in doing predictions, classifications – any kind of inference – with it, we have all we need. To start using the library, you just need the following Maven dependency. The jar itself is on Maven Central and should be fetched automagically:

<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow</artifactId>
    <version>1.4.1</version>
</dependency>

That’s it. Now you can use TensorFlow in your JVM application. This dependency works for any JVM language that can easily call Java libraries, so you can also apply the methods below in your Kotlin, Scala or Clojure program. The jar contains the necessary Java glue code and the native binaries for Windows, Mac OS X and Linux. You do not need a TensorFlow installation on the machine you run the Java code on. This makes it easy to use for web applications, where you may need to deploy to dozens of servers: there is no TensorFlow installation to keep up to date, because everything is contained in the Java dependencies.

The TensorFlow version packaged with the jar is CPU only, but as we are not interested in training in Java, this should not be a problem. If you also need GPU support, you have to compile the jar (and TensorFlow) yourself (you can find the instructions here). Before you go down that road, I recommend checking whether GPU-based inference is actually faster. The GPU provides a huge speed advantage if you run a lot of batches through it and perform backpropagation for learning. If you just want predictions for single instances, the overhead of the communication between CPU, main memory and GPU might actually make things slower.

Loading a SavedModel

Now that we have added the right dependency, loading our SavedModel is very simple:

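import org.tensorflow.SavedModelBundle;
// "serve" is the tag of the MetaGraph that was exported for serving
final SavedModelBundle bundle = SavedModelBundle.load(
    "/path/to/your/saved_model/subfolder", "serve");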

This gives us an org.tensorflow.SavedModelBundle object that we can use to run TensorFlow sessions. The object is thread-safe, so you can initialize it once and reuse it everywhere. The subfolder you have to provide for loading is the folder inside the folder we passed to our Estimator for saving. Its name is an integer timestamp like 1513701267: the epoch time in seconds when the SavedModel was saved, which is why it is different every time you save.
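Because the subfolder name changes with every export, hard-coding it is brittle. The following is a minimal sketch of one way to pick the most recent export, assuming the export base folder holds the timestamped subfolders and possibly some unrelated entries. The names LatestExport and findLatest are made up for this illustration and are not part of the TensorFlow API.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Comparator;
import java.util.Optional;
import java.util.stream.Stream;

public class LatestExport {

    /**
     * Returns the subfolder of exportBase whose name is the largest
     * all-digit number, i.e. the most recent epoch timestamp.
     * Returns an empty Optional if no such subfolder exists.
     */
    public static Optional<Path> findLatest(Path exportBase) throws IOException {
        try (Stream<Path> children = Files.list(exportBase)) {
            return children
                .filter(Files::isDirectory)
                // keep only names such as 1513701267; skip everything else
                .filter(p -> p.getFileName().toString().matches("\\d{1,18}"))
                .max(Comparator.comparingLong(
                    p -> Long.parseLong(p.getFileName().toString())));
        }
    }
}
```

The returned path, converted with toString(), can then be used as the first argument of SavedModelBundle.load.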

But what is the second parameter? As mentioned in the previous article, a SavedModel can contain various MetaGraphs, each identified by a set of tags. The second parameter is a String vararg with the tag or tags of the MetaGraph we want to load, and the resulting SavedModelBundle contains the MetaGraph whose tag set matches the tags you pass. If you only used the export method from the previous post and did not add further MetaGraphs to your SavedModel, it will contain only one MetaGraph with the tag "serve". Otherwise you have to pass the tag(s) of the MetaGraph you want to load.

That’s it, now you can start using your trained model to run predictions from your JVM code. We will look at the details of running a TensorFlow session inside a JVM in the next post of this series.

Postscript: Determining the proper tag

If you use a different API to save your SavedModel than in the Estimator example, you should have no problem figuring out the proper tag to load it again, since you provided it yourself when saving. But maybe you have not trained and exported the model yourself and need to find out which tags are available in a given SavedModel folder? In that case you can use the saved_model_cli command line application that ships with TensorFlow to examine the contents of a SavedModel folder like so:

chris$ ~/Library/Python/3.6/bin/saved_model_cli show \
    --dir saved_models/1512127459/
The given SavedModel contains the following tag-sets:
serve

In this example we only have one tag-set with a single tag, autogenerated by the Estimator. If you create your SavedModel differently, you may have more or different tags. The tags of one of those tag-sets are what you need to pass to your SavedModelBundle.load call.
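If you want to hand that output to Java code, the lines after the header are the tag-sets. Here is a small sketch of a parser, assuming that each tag-set is printed on its own line and that several tags within one set are separated by commas (the output above only shows the single-tag case). TagSets and parse are hypothetical helper names, not TensorFlow API.

```java
import java.util.ArrayList;
import java.util.List;

public class TagSets {

    /** Parses the text printed by "saved_model_cli show --dir ..." into tag-sets. */
    public static List<String[]> parse(String cliOutput) {
        List<String[]> tagSets = new ArrayList<>();
        boolean headerSeen = false;
        for (String line : cliOutput.split("\\R")) {
            String trimmed = line.trim();
            if (trimmed.isEmpty()) {
                continue;
            }
            if (!headerSeen) {
                // ignore everything up to and including the header line
                headerSeen = trimmed.endsWith("tag-sets:");
                continue;
            }
            // assumption: the tags of one set are separated by commas
            tagSets.add(trimmed.split("\\s*,\\s*"));
        }
        return tagSets;
    }
}
```

Each String[] in the result can be passed directly as the vararg, for example as the second argument of SavedModelBundle.load.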

Summary

To load a SavedModel in a JVM, do the following:

  • Add the TensorFlow build dependency
  • Determine the name of the folder where your SavedModel is stored (an integer timestamp)
  • Determine the tag(s) of the MetaGraph you want to load (most likely "serve")
  • Load your SavedModelBundle using SavedModelBundle.load(String exportDir, String... tags)
  • You have to do this only once per JVM instance, since the SavedModelBundle is thread-safe
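The last point can be sketched as a small holder that loads a resource at most once and hands the same instance to every thread. To keep the sketch runnable without the native library, it is written generically. With the TensorFlow dependency on the classpath, the supplier would be something like () -> SavedModelBundle.load(dir, "serve"). LoadOnce is a hypothetical name chosen for this illustration.

```java
import java.util.function.Supplier;

public class LoadOnce<T> {

    private final Supplier<T> loader;
    // volatile so that a fully constructed instance is visible to all threads
    private volatile T instance;

    public LoadOnce(Supplier<T> loader) {
        this.loader = loader;
    }

    /** Returns the shared instance, invoking the loader on first use only. */
    public T get() {
        T result = instance;
        if (result == null) {
            synchronized (this) {
                result = instance;
                if (result == null) {
                    // assumption: the loader never returns null
                    result = loader.get();
                    instance = result;
                }
            }
        }
        return result;
    }
}
```

A single LoadOnce kept in a static field or injected as a singleton is enough for the whole application, and every request handler can call get() without loading the model again.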