Enterprise TensorFlow 3 – Ein SavedModel in Java laden

3. Teil in der Serie über Java / TensorFlow Interoperabilität, der zeigt, wie man ein TensorFlow SavedModel in Java lädt.

Im vorherigen Beitrag haben wir ein trainiertes Modell als SavedModel gespeichert. Dies hinterließ uns einen Ordner mit verschiedenen Binärdateien, die die Graphenstruktur, erlernte Parameter, optionale Assets und die Signatur (Input- und Output-Tensoren) für die Ausführung der Inferenz auf unserem Modell enthalten. Um dieses Ergebnis nun in einer JVM-Umgebung nutzen zu können, müssen wir das Modell laden und ausführen. Zum Glück ist das Laden eines Modells viel einfacher als das Speichern, daher ist dieser Post viel kürzer. Im Gegensatz zum Speichern in Python gibt es nur eine API zum Laden in Java, nicht mehrere, und es funktioniert immer gleich.

Wie beim vorherigen Beitrag findet sich auch hier ein komplettes Beispiel auf github.

Hinzufügen von TensorFlow zu einem Java-Projekt

Als wir die Vor- und Nachteile der Verwendung von Java zur Ausführung unserer trainierten Modelle in einem realen Produkt diskutierten, sahen wir, dass TensorFlow eigentlich eine C++-Computing-Bibliothek mit einem Python-Frontend zum Erstellen und Ausführen von Computing-Graphen ist. Mit JNI kann diese Bibliothek auch in Java verwendet werden. Glücklicherweise hat Google bereits die Grundarbeit geleistet und einen schlanken Wrapper um die TensorFlow-Bibliothek erstellt, der alles bietet, was wir brauchen, um jeden TensorFlow-Graphen auszuführen, solange er als SavedModel gespeichert ist. Die Bibliothek ist jedoch bei weitem nicht komfortabel genug, um mit ihr zu trainieren und zu testen, da die gesamte High-Level-API in Python fehlt.

Aber da wir nur daran interessiert sind, Vorhersagen, Klassifizierungen – jede Art von Inferenz – damit zu machen, haben wir alles, was wir brauchen. Um die Bibliothek nutzen zu können, benötigen Sie lediglich die folgende Maven-Abhängigkeit. Das Jar ist über Maven Central verfügbar und sollte automagisch heruntergeladen werden:

<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow</artifactId>
    <version>1.4.1</version>
</dependency>

 

Das ist alles. Nun können Sie TensorFlow in Ihrer JVM-Anwendung verwenden. Diese Abhängigkeit funktioniert für jede JVM-Sprache, die leicht Java-Bibliotheken aufrufen kann, so dass Sie die unten aufgeführten Methoden auch auf Ihre Kotlin, Scala, Clojure usw. Projekte anwenden können. Das Jar enthält den notwendigen Java-Gluecode und die nativen Binärdateien für Windows, Mac OS X und Linux. Sie benötigen keine TensorFlow-Installation auf dem Rechner, auf dem Sie den Java-Code ausführen, was die Verwendung für Web-Anwendungen vereinfacht, bei denen Sie möglicherweise Dutzende von Servern im Einsatz haben.  Es ist nicht nötig, eine TensorFlow-Installation auf dem neuesten Stand zu halten, alles ist in den Java-Abhängigkeiten enthalten. Die TensorFlow Version, die mit dem Java Jar ausgeliefert wird, enthält nur CPU Anpassungen, aber da wir nicht an Training in Java interessiert sind, sollte dies kein Problem sein. Wenn Sie auch GPU-Unterstützung benötigen, müssen Sie das Jar (und TensorFlow) selbst kompilieren (die Anleitung finden Sie hier). Bevor Sie diesen Weg gehen, empfehle ich zu prüfen, ob die GPU-basierte Inferenz tatsächlich schneller ist. Die GPU bietet einen enormen Geschwindigkeitsvorteil, wenn Sie viele Batches durchlaufen und Backpropagation zum Lernen durchführen. Wenn Sie nur Vorhersagen für einzelne Datenpunkte wünschen, kann der Overhead für die Kommunikation zwischen CPU, Hauptspeicher und GPU die Sache langsamer machen, als eine Vorhersage nur mittels CPU.

Laden eines gespeicherten Modells

Nun, da wir die richtige Abhängigkeit hinzugefügt haben, ist das Laden unseres SavedModels sehr einfach:

final SavedModelBundle bundle = SavedModelBundle.load(
"/path/to/your/saved_model/subfolder", "serving");

Dies gibt uns ein org.tensorflow.SavedModelBundle-Objekt, mit dem wir TensorFlow-Sitzungen ausführen können. Das Objekt ist threadsafe, so dass Sie es einmal initialisieren und überall wiederverwenden können. Der Unterordner, den Sie zum Laden angeben müssen, ist der Ordner innerhalb des Ordners, den wir unserem Estimator zum Speichern übergeben haben. Es sollte einen ganzzahligen Zeitstempel als Namen haben, z.B. 1513701267. Der Name des Unterordners ist die Zeit seit Epoch in Sekunden, zu der das SavedModel gespeichert wurde, weshalb es bei jedem Speichern anders ist.

Aber was ist der zweite Parameter, "serving"? Wie im vorherigen Post erwähnt, kann ein SavedModel verschiedene MetaGraphen enthalten. Der zweite Parameter ist ein String vararg, der die zu verwendenden MetaGraphen identifiziert. Sie können einen oder mehrere MetaGraphen auf einmal laden, das resultierende SavedModelBundle enthält alle von Ihnen angegebenen MetaGraphen. Wenn Sie nur die Exportmethode des vorherigen Posts verwendet haben und nicht mehrere MetaGraphen zu Ihrem gespeicherten Modell hinzugefügt haben, enthält Ihr SavedModel nur einen MetaGraph mit dem Tag "serving". Andernfalls müssen Sie die Tag(s) zum Laden Ihres SavedModelBundle(s) verwenden, die den zu ladenden MetaGraphen entsprechen.

Das war’s, jetzt können Sie Ihr trainiertes Modell verwenden, um Vorhersagen aus Ihrem JVM-Code auszuführen. Wir werden uns die Details der Durchführung einer TensorFlow-Sitzung innerhalb einer JVM im nächsten Beitrag dieser Serie ansehen.

Postscript: Bestimmung des richtigen Tags

Wenn Sie eine andere API verwenden, um Ihr SavedModelBundle zu speichern als im Estimator-Beispiel, sollten Sie kein Problem damit haben, das richtige Tag zu finden, um es wieder zu laden – Sie werden es beim Speichern selbst angegeben haben. Aber vielleicht haben Sie das Modell nicht selbst trainiert und exportiert und müssen herausfinden, welche Tags in einem bestimmten SavedModel-Ordner verfügbar sind? In diesem Fall können Sie die Kommandozeilenanwendung saved_model_cli, die mit TensorFlow ausgeliefert wird, verwenden, um den Inhalt eines SavedModel-Ordners zu untersuchen:

chris$ ~/Library/Python/3.6/bin/saved_model_cli show \
--dir saved_models/1512127459/
The given SavedModel contains the following tag- sets:
serve

In diesem Beispiel haben wir nur einen Tag, der vom Estimator automatisch generiert wird. Wenn Sie Ihr SavedModel anders erstellen, können Sie mehr / andere Tags haben. Es ist einer oder mehrerer dieser Tags, die Sie an Ihren  SavedModelBundle.load -Aufruf übergeben müssen.

Zusammenfassung

Um ein SavedModel in eine JVM zu laden, gehen Sie wie folgt vor:

  • Hinzufügen der TensorFlow-Build-Abhängigkeit
  • Bestimmen Sie den Namen des Ordners, in dem Ihr SavedModel gespeichert ist (ein ganzzahliger Zeitstempel).
  • Bestimmen Sie die Tags der MetaGraphen, die Sie laden möchten (höchstwahrscheinlich"serving").
    Laden Sie Ihr SavedModelBundle mit SavedModelBundle load(String exportDir, String... tags)
  • Sie müssen dies nur einmal pro JVM-Instanz tun, das SavedModelBundle ist thread-safe