草庐IT

python - 有没有一种简单的方法可以在 tensorflow 中将 tf.data.Dataset.from_generator 中的特性与自定义 model_fn(Estimator) 结合使用

coder 2023-08-24 原文

我正在为我的训练数据使用 tensorflow 数据集 api,为 tf.data.Dataset.from_generator api 使用 input_fn 和生成器

def generator():
    ......
    yield { "x" : features }, label


def input_fn():
    ds = tf.data.Dataset.from_generator(generator, ......)
    ......
    feature, label = ds.make_one_shot_iterator().get_next()
    return feature, label

然后我使用如下代码为我的 Estimator 创建了一个自定义 model_fn:

def model_fn(features, labels, mode, params):
    print(features)
    ......
    layer = network.create_full_connect(input_tensor=features["x"], 
    (or layer = tf.layers.dense(features["x"], 200, ......)
    ......

训练时:

estimator.train(input_fn=input_fn)

但是,代码不起作用,因为函数 model_fn 的 features 参数是这样的:

Tensor("IteratorGetNext:0", dtype=float32, device=/device:CPU:0)

代码 "features["x"]"会失败并告诉我:

......"site-packages\tensorflow\python\ops\array_ops.py", line 504, in _SliceHelper end.append(s + 1) TypeError: must be str, not int

如果我将 input_fn 更改为:

input_fn = tf.estimator.inputs.numpy_input_fn(
  x={"x": np.array([[1,2,3,4,5,6]])},
  y=np.array([1]),

代码继续,因为 features 现在是一个字典。

我搜索了 estimator 的代码,发现它使用了一些函数,例如

features, labels = self._get_features_and_labels_from_input_fn(
      input_fn, model_fn_lib.ModeKeys.TRAIN)

从 input_fn 中检索特征和标签,但我不知道为什么它通过使用不同的数据集实现传递给我(model_fn)两种不同数据类型的特征,如果我想使用我的生成器模式,那么如何使用它类型(IteratorGetNext)的功能?

感谢您的帮助!

[更新]

我对代码做了一些修改,

def generator():
    ......
    yield features, label

def input_fn():
    ds = tf.data.Dataset.from_generator(generator, ......)
    ......
    feature, label = ds.make_one_shot_iterator().get_next()
    return {"x": feature}, label

然而,在 tf.layers.dense 仍然失败,现在它说

"Input 0 of layer dense_1 is incompatible with the layer: its rank is undefined, but the layer requires a defined rank."

虽然特征是一个字典:

'x': tf.Tensor 'IteratorGetNext:0' shape=unknown dtype=float64

在正确的情况下,它是:

'x': tf.Tensor 'random_shuffle_queue_DequeueMany:1' shape=(128, 6) dtype=float64

我从

学到了类似的用法

https://developers.googleblog.com/2017/09/introducing-tensorflow-datasets.html

def my_input_fn(file_path, perform_shuffle=False, repeat_count=1):
   def decode_csv(line):
      ......
      d = dict(zip(feature_names, features)), label
      return d

   dataset = (tf.data.TextLineDataset(file_path)

但是对于将迭代器返回到自定义 model_fn 的生成器情况,没有官方示例。

最佳答案

根据examples on how to use from_generator ,生成器返回要放入数据集中的,而不是特征字典。相反,您在 input_fn 中构建字典.

按如下方式更改代码应该可以正常工作:

def generator():
    ......
    yield features, label

def input_fn():
    ds = tf.data.Dataset.from_generator(generator, ......)
    ......
    feature, label = ds.make_one_shot_iterator().get_next()
    return {"x": feature}, label

回复更新:

您的代码失败是因为 Dataset.from_generator 的迭代器生成的张量没有静态 shape已定义(因为生成器原则上可以返回不同形状的数据)。 假设您的数据确实始终具有相同的形状,您可以调用 feature.set_shape(<the_shape_of_your_data>)之前return来自 input_fn (有关执行此操作的正确方法,请参阅编辑打击)。

编辑:

正如您在评论中指出的那样, tf.data.Dataset.from_generator() 有第三个参数设置输出张量的形状,所以不是feature.set_shape()只需将形状传递为 output_shapesfrom_generator() .

关于python - 有没有一种简单的方法可以在 tensorflow 中将 tf.data.Dataset.from_generator 中的特性与自定义 model_fn(Estimator) 结合使用,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47390473/

有关python - 有没有一种简单的方法可以在 tensorflow 中将 tf.data.Dataset.from_generator 中的特性与自定义 model_fn(Estimator) 结合使用的更多相关文章

  1. ruby - 如何使用 Nokogiri 的 xpath 和 at_xpath 方法 - 2

    我正在学习如何使用Nokogiri,根据这段代码我遇到了一些问题:require'rubygems'require'mechanize'post_agent=WWW::Mechanize.newpost_page=post_agent.get('http://www.vbulletin.org/forum/showthread.php?t=230708')puts"\nabsolutepathwithtbodygivesnil"putspost_page.parser.xpath('/html/body/div/div/div/div/div/table/tbody/tr/td/div

  2. ruby - 如何从 ruby​​ 中的字符串运行任意对象方法? - 2

    总的来说,我对ruby​​还比较陌生,我正在为我正在创建的对象编写一些rspec测试用例。许多测试用例都非常基础,我只是想确保正确填充和返回值。我想知道是否有办法使用循环结构来执行此操作。不必为我要测试的每个方法都设置一个assertEquals。例如:describeitem,"TestingtheItem"doit"willhaveanullvaluetostart"doitem=Item.new#HereIcoulddotheitem.name.shouldbe_nil#thenIcoulddoitem.category.shouldbe_nilendend但我想要一些方法来使用

  3. ruby - 为什么我可以在 Ruby 中使用 Object#send 访问私有(private)/ protected 方法? - 2

    类classAprivatedeffooputs:fooendpublicdefbarputs:barendprivatedefzimputs:zimendprotecteddefdibputs:dibendendA的实例a=A.new测试a.foorescueputs:faila.barrescueputs:faila.zimrescueputs:faila.dibrescueputs:faila.gazrescueputs:fail测试输出failbarfailfailfail.发送测试[:foo,:bar,:zim,:dib,:gaz].each{|m|a.send(m)resc

  4. ruby-on-rails - 在 Rails 中将文件大小字符串转换为等效千字节 - 2

    我的目标是转换表单输入,例如“100兆字节”或“1GB”,并将其转换为我可以存储在数据库中的文件大小(以千字节为单位)。目前,我有这个:defquota_convert@regex=/([0-9]+)(.*)s/@sizes=%w{kilobytemegabytegigabyte}m=self.quota.match(@regex)if@sizes.include?m[2]eval("self.quota=#{m[1]}.#{m[2]}")endend这有效,但前提是输入是倍数(“gigabytes”,而不是“gigabyte”)并且由于使用了eval看起来疯狂不安全。所以,功能正常,

  5. ruby - Facter::Util::Uptime:Module 的未定义方法 get_uptime (NoMethodError) - 2

    我正在尝试设置一个puppet节点,但ruby​​gems似乎不正常。如果我通过它自己的二进制文件(/usr/lib/ruby/gems/1.8/gems/facter-1.5.8/bin/facter)在cli上运行facter,它工作正常,但如果我通过由ruby​​gems(/usr/bin/facter)安装的二进制文件,它抛出:/usr/lib/ruby/1.8/facter/uptime.rb:11:undefinedmethod`get_uptime'forFacter::Util::Uptime:Module(NoMethodError)from/usr/lib/ruby

  6. Ruby 方法() 方法 - 2

    我想了解Ruby方法methods()是如何工作的。我尝试使用“ruby方法”在Google上搜索,但这不是我需要的。我也看过ruby​​-doc.org,但我没有找到这种方法。你能详细解释一下它是如何工作的或者给我一个链接吗?更新我用methods()方法做了实验,得到了这样的结果:'labrat'代码classFirstdeffirst_instance_mymethodenddefself.first_class_mymethodendendclassSecond使用类#returnsavailablemethodslistforclassandancestorsputsSeco

  7. ruby - 难道Lua没有和Ruby的method_missing相媲美的东西吗? - 2

    我好像记得Lua有类似Ruby的method_missing的东西。还是我记错了? 最佳答案 表的metatable的__index和__newindex可以用于与Ruby的method_missing相同的效果。 关于ruby-难道Lua没有和Ruby的method_missing相媲美的东西吗?,我们在StackOverflow上找到一个类似的问题: https://stackoverflow.com/questions/7732154/

  8. ruby - 使用 Vim Rails,您可以创建一个新的迁移文件并一次性打开它吗? - 2

    使用带有Rails插件的vim,您可以创建一个迁移文件,然后一次性打开该文件吗?textmate也可以这样吗? 最佳答案 你可以使用rails.vim然后做类似的事情::Rgeneratemigratonadd_foo_to_bar插件将打开迁移生成的文件,这正是您想要的。我不能代表textmate。 关于ruby-使用VimRails,您可以创建一个新的迁移文件并一次性打开它吗?,我们在StackOverflow上找到一个类似的问题: https://sta

  9. ruby - 我可以使用 Ruby 从 CSV 中删除列吗? - 2

    查看Ruby的CSV库的文档,我非常确定这是可能且简单的。我只需要使用Ruby删除CSV文件的前三列,但我没有成功运行它。 最佳答案 csv_table=CSV.read(file_path_in,:headers=>true)csv_table.delete("header_name")csv_table.to_csv#=>ThenewCSVinstringformat检查CSV::Table文档:http://ruby-doc.org/stdlib-1.9.2/libdoc/csv/rdoc/CSV/Table.html

  10. ruby-on-rails - Rails 3.2.1 中 ActionMailer 中的未定义方法 'default_content_type=' - 2

    我在我的项目中添加了一个系统来重置用户密码并通过电子邮件将密码发送给他,以防他忘记密码。昨天它运行良好(当我实现它时)。当我今天尝试启动服务器时,出现以下错误。=>BootingWEBrick=>Rails3.2.1applicationstartingindevelopmentonhttp://0.0.0.0:3000=>Callwith-dtodetach=>Ctrl-CtoshutdownserverExiting/Users/vinayshenoy/.rvm/gems/ruby-1.9.3-p0/gems/actionmailer-3.2.1/lib/action_mailer

随机推荐