Enterprise TensorFlow 4 – Executing a TensorFlow Session in Java

A TensorFlow Session can be executed in Java in the same way as in Python. This post shows how.

We have managed to load a TensorFlow SavedModel in Java. Now it is time to get results out of the model. Luckily, the idiom for this is the same as in the low-level Python API: all we need to do is identify the nodes that define the input and output of our computation graph, wrap our data in Tensors and run a session.

As with the previous posts, you can find the complete working code on GitHub.

Running a TensorFlow session in Java requires the following steps:

  1. Wrap your input data in Tensor objects using the static helpers in the Tensors class
  2. Get a Session object
  3. Create a Runner object for the session
  4. Assign the input Tensors to the proper nodes in your graph with Runner.feed
  5. Define the output you want returned with Runner.fetch
  6. Execute the computation with Runner.run
  7. Unwrap the result Tensors using one of the Tensor convenience methods or a copyTo call
  8. Make sure you close all Tensor objects

This may look daunting, but it is simple in practice, thanks to a well-documented API, plenty of helper functions and a fluent interface.

Wrapping data in org.tensorflow.Tensor objects

The Tensor class is the most important class when using the TensorFlow Java wrapper. It is used to wrap data before feeding it to the TensorFlow engine and to unwrap the results we get back. The most complicated part of running our model in Java is wrapping and unwrapping our data correctly. Luckily, if we get something wrong, the resulting error messages are meaningful and verbose, so this is normally an easy job.

In 99% of all use cases, you can simply call one of the helper methods in the Tensors class to create a Tensor of the proper shape, data type and content. Here is an example of wrapping a single input float value:

final Tensor<Float> t = Tensors.create(f);

There are helper methods for all data types and up to six dimensions, so you should find everything you need there. As we will see later, you may still want to wrap all Tensor creation in functions of your own to make resource handling easier: you must make sure to call .close() on every Tensor you create!

In very rare cases you may have to resort to the create methods on the Tensor class itself, which allow you to create a Tensor of any shape. For completeness' sake, here is an example of manually wrapping a float in a Tensor (do not do this unless you absolutely have to):

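final Tensor<Float> t = Tensor.create(
    // the shape: rank 1 with a single element
    // (Tensors.create(f) produces a rank-0 scalar; new long[0] would match that)
    new long[] {1},
    FloatBuffer.wrap(new float[] {f}) // the data
);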
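Whichever create method you use, the shape and the data have to agree: the number of elements a shape describes is the product of its dimensions, and an empty shape describes a scalar with exactly one element. The following is a minimal sketch of checking this before calling Tensor.create, assuming you assemble the data as a float array yourself. ShapedBuffers, numElements and wrapChecked are hypothetical names, not part of the TensorFlow API, and the sketch only depends on java.nio:

```java
import java.nio.FloatBuffer;
import java.util.Arrays;

/** Hypothetical helper, not part of org.tensorflow. */
final class ShapedBuffers {
    private ShapedBuffers() {}

    /** Number of elements described by a shape; an empty shape is a scalar (1 element). */
    static long numElements(final long[] shape) {
        long n = 1;
        for (final long dim : shape) {
            if (dim < 0) {
                throw new IllegalArgumentException("unknown dimension in " + Arrays.toString(shape));
            }
            n = Math.multiplyExact(n, dim);
        }
        return n;
    }

    /** Wraps the values in a FloatBuffer after checking that they fill the shape exactly. */
    static FloatBuffer wrapChecked(final long[] shape, final float... values) {
        final long expected = numElements(shape);
        if (values.length != expected) {
            throw new IllegalArgumentException("shape " + Arrays.toString(shape)
                    + " needs " + expected + " values, got " + values.length);
        }
        return FloatBuffer.wrap(values);
    }
}
```

With TensorFlow on the classpath, the manual example above would become Tensor.create(shape, ShapedBuffers.wrapChecked(shape, f)) for a long[] shape of your choice. The up-front check trades a little boilerplate for an error message that names both the shape and the number of values you supplied.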

Running a session and retrieving results

As in the low-level Python API, a model is executed in a session. To get a handle to a Session object, we just call the SavedModelBundle.session() method. The Session object is in turn used to get a Runner. The Runner provides a fluent API: Runner.feed binds Tensors to nodes in the graph, and Runner.fetch defines which Tensors to return once the computation is complete. The fluent API works like a builder: each call returns the Runner again, so we can chain calls. When everything is wired up, we call Runner.run() to perform the computation and return the result. The result is a list of Tensors with one entry per Runner.fetch call. Here is a complete example chaining all calls into one long statement:

final Tensor<?> result = 
    // gets the session
    bundle.session() 
    // creates a runner
    .runner() 
    // binds tensors to input nodes in the graph, in our case 
    // `values` is an array of floats, toTensor creates a Tensor
    // object, the first argument is a string with the name of 
    // the input node
    .feed("wine_type"           , toTensor(values[1], tensorsToClose))
    .feed("fixed_acidity"       , toTensor(values[2], tensorsToClose))
    .feed("volatile_acidity"    , toTensor(values[3], tensorsToClose))
    .feed("citric_acid"         , toTensor(values[4], tensorsToClose))
    .feed("residual_sugar"      , toTensor(values[5], tensorsToClose))
    .feed("chlorides"           , toTensor(values[6], tensorsToClose))
    .feed("free_sulfur_dioxide" , toTensor(values[7], tensorsToClose))
    .feed("total_sulfur_dioxide", toTensor(values[8], tensorsToClose))
    .feed("density"             , toTensor(values[9], tensorsToClose))
    .feed("ph"                  , toTensor(values[10], tensorsToClose))
    .feed("sulphates"           , toTensor(values[11], tensorsToClose))
    .feed("alcohol"             , toTensor(values[12], tensorsToClose))
    // define which output tensor to return
    // (you can chain multiple `fetch` calls to
    // return more than one tensor)
    .fetch("dnn/head/logits:0")
    // execute the runner - this returns a list
    .run()
    // We have only one fetch call, so we get a 
    // one-element-list. The `get(0)` call fetches
    // the first element of the list
    .get(0);
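When you fetch several outputs, the entries of the list returned by run() follow the order of your fetch calls, so positional access with get(0), get(1) and so on gets fragile as the number of fetches grows. A minimal sketch of pairing each fetch name with its result, assuming you keep the names in a List and call Runner.fetch once per name in that order. FetchResults.byName is a hypothetical helper; it is generic so it does not depend on the TensorFlow classes:

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical helper: pairs fetch names with the results of Runner.run(). */
final class FetchResults {
    private FetchResults() {}

    /**
     * Maps each fetch name to the result at the same index. Assumes the
     * names are listed in the order in which Runner.fetch was called.
     */
    static <T> Map<String, T> byName(final List<String> fetchNames, final List<T> results) {
        if (fetchNames.size() != results.size()) {
            throw new IllegalArgumentException("expected " + fetchNames.size()
                    + " results, got " + results.size());
        }
        // LinkedHashMap keeps the fetch order when iterating
        final Map<String, T> byName = new LinkedHashMap<>();
        for (int i = 0; i < fetchNames.size(); i++) {
            final String name = fetchNames.get(i);
            if (byName.containsKey(name)) {
                throw new IllegalArgumentException("duplicate fetch name: " + name);
            }
            byName.put(name, results.get(i));
        }
        return byName;
    }
}
```

With TensorFlow, you would create a Session.Runner, call runner.fetch(name) for each name in a loop (fetch returns the Runner, but you do not have to chain), and then pass the names together with runner.run() to byName. The returned Tensors still have to be closed as described under Resource management.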

Unwrapping resulting Tensors

What is left now is to get the result out of the Tensor returned by the run() call. If the result Tensor is a scalar, you can simply call Tensor.floatValue(), Tensor.booleanValue(), etc. If it is not a scalar, the data needs to be retrieved with Tensor.copyTo(U destination), where destination is a multidimensional array of matching type and shape. Prepackaged neural network regression estimators, for example, always return a two-dimensional tensor, even if you only have a single numerical result. In that case, you can retrieve the result like this:

float[][] resultValues = (float[][]) result.copyTo(new float[1][1]);
float prediction = resultValues[0][0];  

The type and number of dimensions of the array depend on your model.
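The new float[1][1] destination above only fits a single prediction. If you feed several rows at once, one approach is to size the destination from the result itself. A minimal sketch, assuming a rank-2 float result such as the (-1, 1) output of a regression estimator; Tensor.shape() returns the concrete shape of a result as a long[] at runtime. ResultArrays, allocate2d and firstColumn are hypothetical helpers:

```java
import java.util.Arrays;

/** Hypothetical helpers for unwrapping rank-2 float results. */
final class ResultArrays {
    private ResultArrays() {}

    /** Allocates a float[rows][cols] destination for a rank-2 shape such as {n, 1}. */
    static float[][] allocate2d(final long[] shape) {
        if (shape.length != 2 || shape[0] < 0 || shape[1] < 0) {
            throw new IllegalArgumentException("expected a concrete rank-2 shape, got "
                    + Arrays.toString(shape));
        }
        return new float[Math.toIntExact(shape[0])][Math.toIntExact(shape[1])];
    }

    /** Extracts the first column, e.g. one prediction per input row from an n x 1 result. */
    static float[] firstColumn(final float[][] values) {
        final float[] column = new float[values.length];
        for (int row = 0; row < values.length; row++) {
            column[row] = values[row][0];
        }
        return column;
    }
}
```

With TensorFlow on the classpath this would read float[][] raw = result.copyTo(ResultArrays.allocate2d(result.shape())); followed by ResultArrays.firstColumn(raw) to get one prediction per input row.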

Resource management

Two types of objects need to be closed manually for proper resource handling: Sessions and Tensors. Note that all Tensor objects, whether created manually or returned from running a session, must be closed manually. I prefer to do this by performing all Tensor creation in helper functions that collect the created Tensors in a Collection, and then freeing everything in a finally block when I am done:

private static Tensor<Float> toTensor(final float f, 
        final Collection<Tensor<?>> tensorsToClose) 
{
    final Tensor<Float> t = Tensors.create(f);
    if (tensorsToClose != null) {
        tensorsToClose.add(t);
    }
    return t;
}       

private static void closeTensors(final Collection<Tensor<?>> ts) {      
    for (final Tensor<?> t : ts) {
        try {
            t.close();
        } catch (final Exception e) {
            // TODO: decide on the error handling best fitting your use case here
            // In most cases logging is the only useful thing left to do
            System.err.println("Error closing Tensor.");
            e.printStackTrace();
        }
    }
    ts.clear();
}

private void runSession(final float foo /* , more params here */) {
    final List<Tensor<?>> tensorsToClose = new ArrayList<Tensor<?>>(); 
    try {            
        // run session
        final List<Tensor<?>> result = bundle.session().runner()
            .feed("foo", toTensor(foo, tensorsToClose))
            // ... feed more tensors as necessary ...
            .fetch("some_node")
            // ... fetch more tensors as necessary ...
            .run(); 
        // mark result for cleanup
        tensorsToClose.addAll(result);
        // ... do something with the result ...
    } finally {
        closeTensors(tensorsToClose);
    }
}
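Since Tensor implements AutoCloseable in the 1.x Java API, the same bookkeeping can also be wrapped in a small class and used with try-with-resources instead of an explicit finally block. The following is a minimal sketch of that variation, not the author's code; CloseableScope, track and trackAll are hypothetical names. It closes resources in reverse order of registration and, like closeTensors above, keeps going when one close fails, but rethrows the first failure with later ones attached as suppressed exceptions:

```java
import java.util.ArrayDeque;
import java.util.Deque;

/**
 * Hypothetical helper (not part of the TensorFlow API): collects AutoCloseable
 * resources such as org.tensorflow.Tensor and closes them all, in reverse
 * order of registration, when the scope itself is closed.
 */
final class CloseableScope implements AutoCloseable {
    private final Deque<AutoCloseable> resources = new ArrayDeque<>();

    /** Registers a (non-null) resource and returns it, so the call can be inlined in feed(...). */
    <T extends AutoCloseable> T track(final T resource) {
        resources.push(resource);
        return resource;
    }

    /** Registers every element, e.g. the List returned by Runner.run(). */
    void trackAll(final Iterable<? extends AutoCloseable> all) {
        for (final AutoCloseable resource : all) {
            track(resource);
        }
    }

    @Override
    public void close() throws Exception {
        Exception first = null;
        // pop() takes the most recently pushed resource first (LIFO)
        while (!resources.isEmpty()) {
            final AutoCloseable resource = resources.pop();
            try {
                resource.close();
            } catch (final Exception e) {
                if (first == null) {
                    first = e;
                } else if (first != e) {
                    first.addSuppressed(e);
                }
            }
        }
        if (first != null) {
            throw first;
        }
    }
}
```

Inside a try (CloseableScope scope = new CloseableScope()) block, you would write .feed("foo", scope.track(Tensors.create(foo))) and, after run(), scope.trackAll(result). Whether to rethrow or merely log close failures is the same judgment call as the TODO in closeTensors.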

Note the absence of Session closing: the session is created once for the SavedModelBundle, and the session() call returns a reference to that existing session, not a new one. The Session is thread-safe, so it can be reused everywhere, and it only needs to be closed when you are completely done, i.e. at the end of your program or when you shut down your server. You can simply do this by closing your SavedModelBundle, which frees all resources associated with the SavedModel. (You may even omit closing the Session, as the end of your JVM process should free all resources associated with it anyway. I never had any negative effects, but do this at your own risk!)
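If you do want the bundle closed explicitly when the process ends, a JVM shutdown hook is one option. A minimal sketch, assuming the SavedModelBundle lives for the whole lifetime of the process; SavedModelBundle implements AutoCloseable in the 1.x Java API, so it can be passed in directly. ShutdownCloser and closeOnShutdown are hypothetical names:

```java
/** Hypothetical helper: closes a long-lived resource when the JVM shuts down. */
final class ShutdownCloser {
    private ShutdownCloser() {}

    /** Registers a shutdown hook that closes the resource; returns the hook so it can be removed again. */
    static Thread closeOnShutdown(final AutoCloseable resource) {
        final Thread hook = new Thread(() -> {
            try {
                resource.close();
            } catch (final Exception e) {
                // At shutdown, logging is usually all that is left to do
                System.err.println("Error closing resource on shutdown: " + e);
            }
        }, "close-on-shutdown");
        Runtime.getRuntime().addShutdownHook(hook);
        return hook;
    }
}
```

You would call ShutdownCloser.closeOnShutdown(bundle) right after loading the model. Keep in mind that shutdown hooks run on normal JVM exit and on interrupts such as Ctrl-C, but not when the process is killed forcibly (e.g. with SIGKILL).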

Determining the proper names for input and output nodes

If you have written your own Estimator, you probably know the names of the input and output nodes you need, as well as their shapes. Sometimes, however, you might have used a prepackaged Estimator and not know what its output nodes are called, or you may not have written the model yourself. In that case, just like for the tag needed to load the MetaGraph in the previous post, you need to inspect your SavedModel on the command line to determine the tag, the input and output node names, their shapes and their data types. This can be done with successive calls to saved_model_cli (we use the SavedModel from our example project here; your output will obviously depend on the model you use):

chris$ ~/Library/Python/3.6/bin/saved_model_cli show \
--dir saved_models/1512127459/
The given SavedModel contains the following tag-sets:
serve

chris$ ~/Library/Python/3.6/bin/saved_model_cli show \
--dir saved_models/1512127459/ \
--tag_set serve
The given SavedModel MetaGraphDef contains SignatureDefs with the following keys:
SignatureDef key: "predict"

chris$ ~/Library/Python/3.6/bin/saved_model_cli show \
--dir saved_models/1512127459/ \
--tag_set serve \
--signature_def predict
The given SavedModel SignatureDef contains the following input(s):
inputs['wine_type'] tensor_info:
dtype: DT_FLOAT
shape: (-1)
name: wine_type:0
...
The given SavedModel SignatureDef contains the following output(s):
outputs['predictions'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1)
name: dnn/head/logits:0
Method name is: tensorflow/serving/predict

As you can see, we need successive calls to saved_model_cli show to “dig deeper” into our SavedModel and determine the names, shapes and data types of the input and output tensors. Regrettably, as far as I know, this information cannot be retrieved generically with the Java API. The data type and shape of result Tensors, however, can be examined at runtime with Tensor.dataType() and Tensor.shape().
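Tensor.shape() also makes it easy to check a result against the signature that saved_model_cli printed. A minimal sketch, assuming you copy the signature shape into a long[] by hand, keeping -1 for the dimension the CLI reports as unknown (usually the batch size, as in shape: (-1, 1) above). ShapeCheck.matches is a hypothetical helper:

```java
/** Hypothetical helper: compares a runtime shape with a signature shape. */
final class ShapeCheck {
    private ShapeCheck() {}

    /**
     * Returns true if a concrete runtime shape (e.g. from Tensor.shape())
     * matches a signature shape as printed by saved_model_cli, where -1
     * stands for "any size" in that dimension.
     */
    static boolean matches(final long[] actual, final long[] signature) {
        if (actual.length != signature.length) {
            return false; // different rank
        }
        for (int i = 0; i < actual.length; i++) {
            if (signature[i] != -1 && signature[i] != actual[i]) {
                return false;
            }
        }
        return true;
    }
}
```

At runtime you would call ShapeCheck.matches(result.shape(), new long[] {-1, 1}) before copying the data out, and fail with a descriptive message if it returns false. A check on result.dataType() can be added along the same lines.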

Summary

Running a TensorFlow session in Java is straightforward; just remember:

  • SavedModelBundle and the corresponding Session are thread-safe; Tensors are not
  • Use the saved_model_cli to determine the name and shape of your input and output nodes
  • Wrap your input data with the helper methods in the Tensors class
  • Use the fluent API on the SavedModelBundle to acquire and run a session: bundle.session().runner().feed(...).fetch(...).run()
  • For scalar results: Use Tensor.floatValue() etc. to retrieve data from the resulting Tensors.
  • For non-scalar results: Use the proper array type and shape to retrieve data from your resulting Tensors using Tensor.copyTo
  • Only call close on your SavedModelBundle when you are completely done and want to shut down your JVM, e.g. on server shutdown