Enterprise TensorFlow 4 – Eine TensorFlow Session in Java ausführen

Eine TensorFlow Session wird in Java genauso ausgeführt wie in Python. Dieser Post zeigt wie.

Im vorigen Post wurde gezeigt was nötig ist, um ein TensorFlow SavedModel in Java zu laden. Jetzt ist es an der Zeit, Ergebnisse mit dem Modell zu erzeugen. Zum Glück ist das Vorgehen auf der JVM hier genau das gleiche wie bei Low-Level-TensorFlow in Python. Alles, was wir tun müssen, ist, die Knoten zu identifizieren, die die Ein- und Ausgabe unseres Computing-Graphen definieren, Daten in Tensoren zu verpacken und in einer Session die Berechnung durchzuführen.

Wie bei den vorherigen Beiträgen ist der komplette Code auf github zu finden.

Das Ausführen einer TensorFlow-Session in Java erfordert die folgenden Schritte:

  1. Eingabedaten in Tensorobjekte verpacken, mit statischen Helfern der Tensors-Klasse.
  2. Session-Objekt holen
  3. Erstellen eines Runner-Objekts für die Sitzung
  4. Eingangs-Tensoren mit Runner.feed den richtigen Knoten im Graphen zuordnen
  5. Definieren der Ausgabe, die mit Runner.fetch zurückgegeben werden soll.
  6. Ausführen der Berechnung mit Runner.run
  7. Entpacken der Ergebnis Tensoren mit einer der Utility-Methoden der Tensors-Klasse oder einem copyTo-Aufruf.
  8. Nicht vergessen, alle Resourcen (Tensorobjekte) zu schließen.

Das sieht vielleicht ziemlich beängstigend aus, ist aber in der Praxis sehr einfach, dank einer gut dokumentierten API, vielen Hilfsfunktionen und einem netten fluent-Interface.

Daten inorg.tensorflow.Tensor-Objekte einpacken

Die Tensor-Klasse ist die wichtigste Klasse beim Einsatz des TensorFlow Java Wrappers. Sie wird verwendet, um Daten zu verpacken und zu entpacken, um sie der TensorFlow-Engine zuzuführen und Ergebnisse zu erhalten. Der komplizierteste Teil der Ausführung unseres Modells in Java ist das korrekte Ein- und Auspacken unserer Daten. Zum Glück, sind im Fehlerfall die daraus resultierenden Fehlermeldungen sehr aussagekräftig und ausführlich, so dass dies normalerweise ein einfacher Job ist.

In 99% aller Anwendungsfälle reicht einfach eine der Hilfsmethoden in der Tensors-Klasse, um einen Tensor mit der richtigen Form, dem richtigen Datentyp und dem richtigen Inhalt zu erstellen. Hier ist ein Beispiel für das verpacken eines einzelnen Input-Fließkommawertes:

final Tensor<Float> t = Tensors.create(f);

Es gibt Hilfsmethoden für alle Datentypen und bis zu sechs Dimensionen, daher sollten dort alles zu finden sein, was man braucht. Wie wir später sehen werden, ist es dennoch sinnvoll, das erzeugen von Tensor-Objekten in r-Erstellungen in eigene Funktionsaufrufe zu packen, um das Ressourcenhandling ein wenig zu vereinfachen – es muss sichergestellt werden .close() auf allen erstellten Tensoren aufgerufen wird!

In sehr, sehr seltenen Fällen müssen Sie auf die Tensor-Klasse selbst zurückgreifen – so können Sie einen Tensor jeder beliebigen Form erstellen. Der Vollständigkeit halber sei hier ein Beispiel für das manuelle Verpacken eines float in einen Tensor genannt (Finger weg von dieser Methode, es sei denn, es ist unbedingt nötig):

final Tensor<Float> t = Tensor.create(
    new long[] {1}, // the shape
    FloatBuffer.wrap(new float[] {f}) // the data
); 

Ausführen einer Sitzung und Output der Ergebnisse

Wie in der Python Low-Level-API wird ein Modell in einer Sitzung ausgeführt. Um ein Handle auf ein Session-Objekt zu bekommen, rufen wir einfach die Methode SavedModelBundle.session() auf. Das Session-Objekt wird wiederum verwendet, um einen Runner zu erhalten. Der Runner bietet eine fluent API, die verwendet wird, um Tensoren an Knoten im Graphen mit Runner.feed zu binden und zu definieren, welche Tensoren nach der Berechnung mit Runner.fetch zurückgegeben werden sollen. Die fluent API funktioniert wie ein Builder, jeder Aufruf gibt wieder den Runner zurück, so dass wir Aufrufe verketten können. Wenn alles verkabelt ist, rufen wir Runner.run() auf, um die Berechnung durchzuführen und das Ergebnis zurückzugeben. Das Ergebnis ist eine Liste von Tensoren, die Anzahl der Elemente in der Liste hängt von der Anzahl der Runner.fetch-Aufrufe ab, jeder Aufruf erzeugt einen zusätzlichen Listeneintrag. Dies ist ein vollständiges Beispiel, das alle Aufrufe zu einer langen Anweisung verkettet:

final Tensor<?> result = 
    // gets the session
    bundle.session() 
    // creates a runner
    .runner() 
    // binds tensors to input nodes in the graph, in our case 
    // `values` is an array of floats, toTensor creates a Tensor
    // object, the first argument is a string with the name of 
    // the input node
    .feed("wine_type"           , toTensor(values[1], tensorsToClose))
    .feed("fixed_acidity"       , toTensor(values[2], tensorsToClose))
    .feed("volatile_acidity"    , toTensor(values[3], tensorsToClose))
    .feed("citric_acid"         , toTensor(values[4], tensorsToClose))
    .feed("residual_sugar"      , toTensor(values[5], tensorsToClose))
    .feed("chlorides"           , toTensor(values[6], tensorsToClose))
    .feed("free_sulfur_dioxide" , toTensor(values[7], tensorsToClose))
    .feed("total_sulfur_dioxide", toTensor(values[8], tensorsToClose))
    .feed("density"             , toTensor(values[9], tensorsToClose))
    .feed("ph"                  , toTensor(values[10], tensorsToClose))
    .feed("sulphates"           , toTensor(values[11], tensorsToClose))
    .feed("alcohol"             , toTensor(values[12], tensorsToClose))
    // define which output tensor to return
    // (you can chain multiple `fetch` calls to 
    // return more then one tensor)
    .fetch("dnn/head/logits:0")
    // execute the runner - this returns a list
    .run()
    // We have only one fetch call, so we get a 
    // one-element-list. The `get(0)` call fetches
    // the first element of the list
    .get(0);

Auspacken der resultierenden Tensoren

Was jetzt noch bleibt, ist, das Ergebnis aus dem Tensor zu holen, der vom run()-Aufruf zurückgegeben wird. Wenn der Ergebnis-Tensor nur ein Skalar ist, kann dafür einfach Tensor.floatValue(), Tensor.booleanValue() usw. aufgerufen werden. Wenn der resultierende Tensor kein Skalar ist, müssen die resultierenden Daten mit Tensor.copyTo(U destination) abgerufen werden, wobei destination ein multidimensionales Array ist. Eine vorgefertigte Regression mit neuronalen Netzen gibt beispielsweise immer einen zweidimensionalen Tensor zurück, auch wenn en nur ein einziges numerisches Ergebnis gibt. In diesem Fall kommt man wie folgt an das Ergebnis:

float[][] resultValues = (float[][]) result.copyTo(new float[1][1]);
float prediction = resultValues[0][0];  

Der Type und die Anzahl der Dimensionen des Arrays hängen vom jeweiligen Modell ab.

Ressourcenverwaltung

Zwei Arten von Objekten müssen manuell geschlossen werden: Sessions und Tensoren. Beachten Sie, dass alle Tensor-Objekte – ob manuell erstellt oder aus einer laufenden Sitzung zurückgegeben – manuell geschlossen werden müssen. Ich ziehe es vor, dies zu tun, indem ich Tensor-Erzeuging immer in einer Hilfsfunktion ausführe, die alle erstellten Tensoren in einer Collection sammelt um dann alles in einem finally Block freizugeben, wenn ich fertig bin:

 

private static Tensor<Float> toTensor(final float f, 
        final Collection<Tensor<?>> tensorsToClose) 
{
    final Tensor<Float> t = Tensors.create(f);
    if (tensorsToClose != null) {
        tensorsToClose.add(t);
    }
    return t;
}       

private static void closeTensors(final Collection<Tensor<?>> ts) {      
    for (final Tensor<?> t : ts) {
        try {
            t.close();
        } catch (final Exception e) {
            // TODO: decide on the error handling best fitting your use case here
            // In most cases logging is the only useful thing left to do
            System.err.println("Error closing Tensor.");
            e.printStackTrace();
        }
    }
    ts.clear();
}

private void runSession(final float foo, /* more params here */) {
    final List<Tensor<?>> tensorsToClose = new ArrayList<Tensor<?>>(); 
    try {            
        // run session
        final List<Tensor<?>> result = bundle.session().runner()
            .feed("foo", toTensor(foo, tensorsToClose))
            // ... feed more tensors as necessary ...
            .fetch("some_node")
            // ... fetch more tensors as necessary ...
            .run(); 
        // mark result for cleanup
        tensorsToClose.addAddAll(result);
        // ... do something with the result ...
    } finally {
        closeTensors(tensorsToClose);
    }
}

Hierbei muss man darauf achten, auf der Session nicht close()aufzurufen! Die Session wird einmalig für das SavedModelBundle erstellt, der session()Aufruf gibt eine bestehende Referenz zurück, keine neue Session. Die Session ist thread-safe, so dass sie überall wiederverwendet werden kann. Sie muss erst geschlossen werden, wenn keine weiteren Aufrufe mehr folgen. Daher sollten die Session nur am Ende des Programms oder beim Herunterfahren eines Servers beendet werden. Sie können dies einfach tun, indem Sie Ihr SavedModelBundle schließen, das alle mit dem SavedModel verbundenen Ressourcen freigibt. (Sie können sogar das Schließen der Sitzung auslassen, da das Ende Ihres JVM-Prozesses ohnehin alle damit verbundenen Ressourcen schließen sollte – ich hatte nie irgendwelche negativen Auswirkungen, aber tun Sie dies auf eigenes Risiko!)

Die richtigen Namen für Ein- und Ausgabeknoten bestimmen

Wenn Sie einen eigenen Estimator geschrieben haben, wissen Sie wahrscheinlich wie Ihre Ein- und Ausgabeknoten heißen und welche Form sie haben. Manchmal haben Sie jedoch einen vorgefertigten Estimator verwendet, bei dem Sie nicht wissen, wie die Ausgabeknoten aufgerufen werden, oder Sie haben das Modell nicht selbst geschrieben und müssen die gespeicherten Daten überprüfen, um zu wissen, was Sie aufrufen sollen. In diesem Fall müssen Sie Ihr SavedModel auf der Kommandozeile überprüfen, um Ihr Tag, Ihre Ein- und Ausgabeknotennamen, deren Formen und Datentypen zu bestimmen. Dies kann durch aufeinanderfolgende Aufrufe des saved_model_cli geschehen (wir verwenden hier das SavedModel aus unserem Beispielprojekt, Ihre Ausgabe hängt natürlich vom verwendeten Modell ab):

 

chris$ ~/Library/Python/3.6/bin/saved_model_cli show \ 
--dir saved_models/1512127459/
The given SavedModel contains the following tag- sets:
serve

chris$ ~/Library/Python/3.6/bin/saved_model_cli show \
--dir saved_models/1512127459/ \
--tag_set serve
The given SavedModel MetaGraphDef contains
SignatureDefs with the following keys:
SignatureDef key: "predict"

chris$ ~/Library/Python/3.6/bin/saved_model_cli show\
--dir saved_models/1512127459/ \
--tag_set serve \
--signature_def predict
The given SavedModel SignatureDef contains the following input(s):
inputs['wine_type'] tensor_info:
dtype: DT_FLOAT
shape: (-1)
name: wine_type:0
...
The given SavedModel SignatureDef contains the
following output(s):
outputs['predictions'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1)
name: dnn/head/logits:0
Method name is: tensorflow/serving/predict

Wie Sie sehen können, benötigen wir aufeinanderfolgende Aufrufe von saved_model_cli show, um tiefer in unserem SavedModel vorzudringen, um die Namen, Formen und Datentypen der Ein- und Ausgangstensoren zu bestimmen. Leider können diese Informationen nicht generisch mit der Java-API (AFAIK) abgerufen werden. Die Art und Form des Ergebnis-Tensors zumindest kann aber auch durch Aufrufe von Tensor.shape() und Tensor.dataType() überprüft werden.

Zusammenfassung

Eine TensorFlow-Sitzung in Java auszuführen ist ziemlich einfach. Hier noch einmal eine Checkliste:

  • SavedModelBundle und die zugehörige Session sind thread-safe, Tensoren nicht
  • Verwenden Sie das saved_model_cli, um die Namen und die Form Ihrer Ein- und Ausgabeknoten zu bestimmen.
  • Verpacken Sie Ihre Eingabedaten mit den Hilfsmethoden in der Klasse Tensors
  • Verwenden Sie die fluent-API des SavedModelBundleBundle, um eine Session zu erhalten und auszuführen: bundle.session().runner().feed(...).fetch(...).run()
  • Für skalare Ergebnisse: Verwenden Sie Tensor.floatValue() usw., um Daten aus den resultierenden Tensoren zu holen.
  • Für nicht skalare Ergebnisse: Verwenden Sie den richtigen Array-Typ und die richtige Form, um Daten aus Ihren resultierenden Tensoren mit Tensor.copyTo() zu holen.
  • Schließen Sie das SavedModelBundle erst dann, wenn Sie Ihre JVM beenden möchten, z.B. beim Herunterfahren des Servers.