草庐IT

java - Tensorflow Java 多 GPU 推理

coder 2023-09-02 原文

我有一台带有多个 GPU 的服务器,我想在 Java 应用程序内的模型推理期间充分利用它们。 默认情况下,tensorflow 占用所有可用的 GPU,但仅使用第一个。

我可以想到三个选项来解决这个问题:

  1. 在进程级别限制设备可见性,即使用 CUDA_VISIBLE_DEVICES 环境变量。

    这将需要我运行 java 应用程序的多个实例并在它们之间分配流量。不是那种诱人的想法。

  2. 在单个应用程序中启动多个 session ,并尝试通过 ConfigProto 为每个 session 分配一个设备:

    public class DistributedPredictor {
    
        private Predictor[] nested;
        private int[] counters;
    
        // ...
    
        public DistributedPredictor(String modelPath, int numDevices, int numThreadsPerDevice) {
            nested = new Predictor[numDevices];
            counters = new int[numDevices];
    
            for (int i = 0; i < nested.length; i++) {
                nested[i] = new Predictor(modelPath, i, numDevices, numThreadsPerDevice);
            }
        }
    
        public Prediction predict(Data data) {
            int i = acquirePredictorIndex();
            Prediction result = nested[i].predict(data);
            releasePredictorIndex(i);
            return result;
        }
    
        private synchronized int acquirePredictorIndex() {
            int i = argmin(counters);
            counters[i] += 1;
            return i;
        }
    
        private synchronized void releasePredictorIndex(int i) {
            counters[i] -= 1;
        }
    }
    
    
    public class Predictor {
    
        private Session session;
    
        public Predictor(String modelPath, int deviceIdx, int numDevices, int numThreadsPerDevice) {
    
            GPUOptions gpuOptions = GPUOptions.newBuilder()
                    .setVisibleDeviceList("" + deviceIdx)
                    .setAllowGrowth(true)
                    .build();
    
            ConfigProto config = ConfigProto.newBuilder()
                    .setGpuOptions(gpuOptions)
                    .setInterOpParallelismThreads(numDevices * numThreadsPerDevice)
                    .build();
    
            byte[] graphDef = Files.readAllBytes(Paths.get(modelPath));
            Graph graph = new Graph();
            graph.importGraphDef(graphDef);
    
            this.session = new Session(graph, config.toByteArray());
        }
    
        public Prediction predict(Data data) {
            // ...
        }
    }
    

    乍一看,这种方法似乎工作正常。但是, session 偶尔会忽略 setVisibleDeviceList 选项,并且所有这些都会针对第一个导致内存不足崩溃的设备。

  3. 使用 tf.device() 规范在 python 中以多塔方式构建模型。在 Java 端,在共享 session 中为不同的 Predictor 提供不同的塔。

    对我来说感觉很麻烦而且惯用错误。

更新:正如@ash 所建议的,还有另一种选择:

  1. 通过修改其定义 (graphDef) 为现有图的每个操作分配适当的设备。

    要完成它,可以修改方法 2 中的代码:

    public class Predictor {
    
        private Session session;
    
        public Predictor(String modelPath, int deviceIdx, int numDevices, int numThreadsPerDevice) {
    
            byte[] graphDef = Files.readAllBytes(Paths.get(modelPath));
            graphDef = setGraphDefDevice(graphDef, deviceIdx)
    
            Graph graph = new Graph();
            graph.importGraphDef(graphDef);
    
            ConfigProto config = ConfigProto.newBuilder()
                    .setAllowSoftPlacement(true)
                    .build();
    
            this.session = new Session(graph, config.toByteArray());
        }
    
        private static byte[] setGraphDefDevice(byte[] graphDef, int deviceIdx) throws InvalidProtocolBufferException {
            String deviceString = String.format("/gpu:%d", deviceIdx);
    
            GraphDef.Builder builder = GraphDef.parseFrom(graphDef).toBuilder();
            for (int i = 0; i < builder.getNodeCount(); i++) {
                builder.getNodeBuilder(i).setDevice(deviceString);
            }
            return builder.build().toByteArray();
        }
    
        public Prediction predict(Data data) {
            // ...
        }
    }
    

    就像其他提到的方法一样,这个方法并没有让我从手动在设备之间分发数据中解脱出来。但至少它工作稳定并且比较容易实现。总的来说,这看起来像是一种(几乎)正常的技术。

有没有一种优雅的方法可以用 tensorflow java API 做这样的基本事情?任何想法,将不胜感激。

最佳答案

简而言之:有一种变通方法,您最终每个 GPU 有一个 session 。

详细信息:

一般流程是 TensorFlow 运行时尊重为图中的操作指定的设备。如果没有为操作指定设备,则它会根据一些试探法“放置”它。这些启发式方法目前导致“在 GPU 上进行操作:0 如果 GPU 可用并且有用于操作的 GPU 内核”(Placer::Run 如果您感兴趣的话)。

我认为你要求的是对 TensorFlow 的合理功能请求 - 能够将序列化图中的设备视为“虚拟”设备以在运行时映射到一组“物理”设备,或者设置“默认设备”。此功能当前不存在。将这样的选项添加到 ConfigProto 是您可能想要提交功能请求的内容。

我可以在此期间建议一个解决方法。首先,对您提出的解决方案进行一些评论。

  1. 您的第一个想法肯定会奏效,但正如您指出的那样,它很麻烦。

  2. ConfigProto 中使用 visible_device_list 进行设置并不完全可行,因为这实际上是每个进程的设置,在创建第一个 session 后会被忽略进行中。这当然没有按应有的方式记录(并且有点不幸,这出现在每 session 配置中)。但是,这解释了为什么您在此处的建议不起作用以及为什么您仍然看到正在使用单个 GPU。

  3. 这可行。

另一种选择是最终得到不同的图形(操作明确地放置在不同的 GPU 上),从而导致每个 GPU 一个 session 。像这样的东西可以用来编辑图形并为每个操作明确分配一个设备:

public static byte[] modifyGraphDef(byte[] graphDef, String device) throws Exception {
  GraphDef.Builder builder = GraphDef.parseFrom(graphDef).toBuilder();
  for (int i = 0; i < builder.getNodeCount(); ++i) {
    builder.getNodeBuilder(i).setDevice(device);
  }
  return builder.build().toByteArray();
} 

之后,您可以使用类似以下内容为每个 GPU 创建一个 GraphSession:

final int NUM_GPUS = 8;
// setAllowSoftPlacement: Just in case our device modifications were too aggressive
// (e.g., setting a GPU device on an operation that only has CPU kernels)
// setLogDevicePlacment: So we can see what happens.
byte[] config =
    ConfigProto.newBuilder()
        .setLogDevicePlacement(true)
        .setAllowSoftPlacement(true)
        .build()
        .toByteArray();
Graph graphs[] = new Graph[NUM_GPUS];
Session sessions[] = new Session[NUM_GPUS];
for (int i = 0; i < NUM_GPUS; ++i) {
  graphs[i] = new Graph();
  graphs[i].importGraphDef(modifyGraphDef(graphDef, String.format("/gpu:%d", i)));
  sessions[i] = new Session(graphs[i], config);    
}

然后使用sessions[i] 在 GPU #i 上执行图形。

希望对您有所帮助。

关于java - Tensorflow Java 多 GPU 推理,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47799972/

有关java - Tensorflow Java 多 GPU 推理的更多相关文章

  1. java - 等价于 Java 中的 Ruby Hash - 2

    我真的很习惯使用Ruby编写以下代码:my_hash={}my_hash['test']=1Java中对应的数据结构是什么? 最佳答案 HashMapmap=newHashMap();map.put("test",1);我假设? 关于java-等价于Java中的RubyHash,我们在StackOverflow上找到一个类似的问题: https://stackoverflow.com/questions/22737685/

  2. java - 从 JRuby 调用 Java 类的问题 - 2

    我正在尝试使用boilerpipe来自JRuby。我看过guide从JRuby调用Java,并成功地将它与另一个Java包一起使用,但无法弄清楚为什么同样的东西不能用于boilerpipe。我正在尝试基本上从JRuby中执行与此Java等效的操作:URLurl=newURL("http://www.example.com/some-location/index.html");Stringtext=ArticleExtractor.INSTANCE.getText(url);在JRuby中试过这个:require'java'url=java.net.URL.new("http://www

  3. java - 我的模型类或其他类中应该有逻辑吗 - 2

    我只想对我一直在思考的这个问题有其他意见,例如我有classuser_controller和classuserclassUserattr_accessor:name,:usernameendclassUserController//dosomethingaboutanythingaboutusersend问题是我的User类中是否应该有逻辑user=User.newuser.do_something(user1)oritshouldbeuser_controller=UserController.newuser_controller.do_something(user1,user2)我

  4. java - 什么相当于 ruby​​ 的 rack 或 python 的 Java wsgi? - 2

    什么是ruby​​的rack或python的Java的wsgi?还有一个路由库。 最佳答案 来自Python标准PEP333:Bycontrast,althoughJavahasjustasmanywebapplicationframeworksavailable,Java's"servlet"APImakesitpossibleforapplicationswrittenwithanyJavawebapplicationframeworktoruninanywebserverthatsupportstheservletAPI.ht

  5. Observability:从零开始创建 Java 微服务并监控它 (二) - 2

    这篇文章是继上一篇文章“Observability:从零开始创建Java微服务并监控它(一)”的续篇。在上一篇文章中,我们讲述了如何创建一个Javaweb应用,并使用Filebeat来收集应用所生成的日志。在今天的文章中,我来详述如何收集应用的指标,使用APM来监控应用并监督web服务的在线情况。源码可以在地址 https://github.com/liu-xiao-guo/java_observability 进行下载。摄入指标指标被视为可以随时更改的时间点值。当前请求的数量可以改变任何毫秒。你可能有1000个请求的峰值,然后一切都回到一个请求。这也意味着这些指标可能不准确,你还想提取最小/

  6. 【Java 面试合集】HashMap中为什么引入红黑树,而不是AVL树呢 - 2

    HashMap中为什么引入红黑树,而不是AVL树呢1.概述开始学习这个知识点之前我们需要知道,在JDK1.8以及之前,针对HashMap有什么不同。JDK1.7的时候,HashMap的底层实现是数组+链表JDK1.8的时候,HashMap的底层实现是数组+链表+红黑树我们要思考一个问题,为什么要从链表转为红黑树呢。首先先让我们了解下链表有什么不好???2.链表上述的截图其实就是链表的结构,我们来看下链表的增删改查的时间复杂度增:因为链表不是线性结构,所以每次添加的时候,只需要移动一个节点,所以可以理解为复杂度是N(1)删:算法时间复杂度跟增保持一致查:既然是非线性结构,所以查询某一个节点的时候

  7. 【Java入门】使用Java实现文件夹的遍历 - 2

    遍历文件夹我们通常是使用递归进行操作,这种方式比较简单,也比较容易理解。本文为大家介绍另一种不使用递归的方式,由于没有使用递归,只用到了循环和集合,所以效率更高一些!一、使用递归遍历文件夹整体思路1、使用File封装初始目录,2、打印这个目录3、获取这个目录下所有的子文件和子目录的数组。4、遍历这个数组,取出每个File对象4-1、如果File是否是一个文件,打印4-2、否则就是一个目录,递归调用代码实现publicclassSearchFile{publicstaticvoidmain(String[]args){//初始目录Filedir=newFile("d:/Dev");Datebeg

  8. java - 为什么 ruby​​ modulo 与 java/other lang 不同? - 2

    我基本上来自Java背景并且努力理解Ruby中的模运算。(5%3)(-5%3)(5%-3)(-5%-3)Java中的上述操作产生,2个-22个-2但在Ruby中,相同的表达式会产生21个-1-2.Ruby在逻辑上有多擅长这个?模块操作在Ruby中是如何实现的?如果将同一个操作定义为一个web服务,两个服务如何匹配逻辑。 最佳答案 在Java中,模运算的结果与被除数的符号相同。在Ruby中,它与除数的符号相同。remainder()在Ruby中与被除数的符号相同。您可能还想引用modulooperation.

  9. java - Ruby 相当于 Java 的 Collections.unmodifiableList 和 Collections.unmodifiableMap - 2

    Java的Collections.unmodifiableList和Collections.unmodifiableMap在Ruby标准API中是否有等价物? 最佳答案 使用freeze应用程序接口(interface):Preventsfurthermodificationstoobj.ARuntimeErrorwillberaisedifmodificationisattempted.Thereisnowaytounfreezeafrozenobject.SeealsoObject#frozen?.Thismethodretur

  10. java - Java 的 StringReader 的 Ruby 等价物是什么? - 2

    在Java中,可以像这样从一个字符串创建一个IO流:Readerr=newStringReader("mytext");我希望能够在Ruby中做同样的事情,这样我就可以获取一个字符串并将其视为一个IO流。 最佳答案 r=StringIO.new("mytext")和here'sthedocumentation. 关于java-Java的StringReader的Ruby等价物是什么?,我们在StackOverflow上找到一个类似的问题: https://st

随机推荐