草庐IT

python - 我的神经网络实现有什么问题?

coder 2023-05-24 原文

我想绘制神经网络相对于训练示例数量的学习误差曲线。这是代码:

import sklearn
import numpy as np
from sklearn.model_selection import learning_curve
import matplotlib.pyplot as plt
from sklearn import neural_network
from sklearn import cross_validation

myList=[]
myList2=[]
w=[]

dataset=np.loadtxt("data", delimiter=",")
X=dataset[:, 0:6]
Y=dataset[:,6]
clf=sklearn.neural_network.MLPClassifier(hidden_layer_sizes=(2,3),activation='tanh')

# split the data between training and testing
X_train, X_test, Y_train, Y_test = cross_validation.train_test_split(X, Y, test_size=0.25, random_state=33)

# begin with few training datas
X_eff=X_train[0:int(len(X_train)/150), : ]
Y_eff=Y_train[0:int(len(Y_train)/150)]

k=int(len(X_train)/150)-1


for m in range (140) :


    print (m)

    w.append(k)

    # train the model and store the training error
    A=clf.fit(X_eff,Y_eff)
    myList.append(1-A.score(X_eff,Y_eff))

      # compute the testing error
    myList2.append(1-A.score(X_test,Y_test))

    # add some more training datas
    X_eff=np.vstack((X_eff,X_train[k+1:k+101,:]))
    Y_eff=np.hstack((Y_eff,Y_train[k+1:k+101]))
    k=k+100

plt.figure(figsize=(8, 8))
plt.subplots_adjust()
plt.title("Erreur d'entrainement et de test")
plt.plot(w,myList,label="training error")
plt.plot(w,myList2,label="test error")
plt.legend()
plt.show()

但是,我得到了一个非常奇怪的结果,曲线波动,训练误差非常接近测试误差,这似乎不正常。 错误在哪里?我不明白为什么会有这么多起伏,为什么训练错误没有像预期的那样增加。任何帮助将不胜感激!

编辑:我使用的数据集是 https://archive.ics.uci.edu/ml/datasets/Chess+%28King-Rook+vs.+King%29我摆脱了少于 1000 个实例的类。我手动重新编码了乱码数据。

最佳答案

我认为您看到这种曲线的原因是您测量的性能指标与您优化的性能指标不同。

优化指标

神经网络最小化损失函数,在 tanh 激活的情况下,我假设您使用的是交叉熵损失的修改版本。如果您要绘制随时间变化的损失,您会看到一个更单调递减的误差函数,如您所料。 (实际上并不是单调的,因为神经网络是非凸的,但这不是重点。)

性能指标

您测量的性能指标是准确度百分比,它不同于损失。为什么这些不同?损失函数以可微分的方式告诉我们有多少误差(这对于快速优化方法很重要)。准确度指标告诉我们预测的好坏,这对于神经网络的应用很有用。

把它放在一起

因为您正在绘制相关指标的性能,您可以预期该图看起来与您的优化指标相似。但是,由于它们不相同,您可能会在您的情节中引入一些无法解释的差异(正如您发布的情节所证明的那样)。

有几种方法可以解决此问题。

  1. 绘制损失而不是准确度。如果您确实需要准确度图,这实际上并不能解决您的问题,但它会为您提供更平滑的曲线。
  2. 绘制多次运行的平均值。保存算法的 20 次独立运行的准确度图(如训练网络 20 次),然后将它们平均在一起并绘制它。这将大大减少方差。

TL;DR

不要期望准确度图总是平滑且单调递减,它不会。


问题编辑后:

现在您已经添加了数据集,我看到了其他一些可能导致您遇到的问题的事情。

量级信息

数据集定义了几个棋子的等级和文件(行和列)。这些是作为从 1 到 6 的整数输入的。但是 2 真的 1 比 1 好吗? 6真的4比2好吗?就棋位而言,我认为情况并非如此。

想象一下,我正在构建一个以金钱为输入的分类器。我的值(value)观是否刻画了一些信息?是的,1 美元与 100 美元完全不同;我们可以根据大小判断出存在关系。

对于国际象棋游戏,第 1 行的含义是否与第 8 行不同?一点也不,事实上这些尺寸是对称的!在您的网络中使用偏置单元可以通过“重新调整”您的输入以有效地从 [-3, 4] (现在以 0 为中心(ish)左右)来帮助解释对称性。

解决方案

但是,我认为,您可以通过平铺编码或一次性编码您的每个功能获得最大的 yield 。不要让网络依赖于每个特征的量级中包含的信息,因为这可能会导致网络进入糟糕的局部最优状态。

关于python - 我的神经网络实现有什么问题?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45919315/

有关python - 我的神经网络实现有什么问题?的更多相关文章

  1. ruby - 为什么我可以在 Ruby 中使用 Object#send 访问私有(private)/ protected 方法? - 2

    类classAprivatedeffooputs:fooendpublicdefbarputs:barendprivatedefzimputs:zimendprotecteddefdibputs:dibendendA的实例a=A.new测试a.foorescueputs:faila.barrescueputs:faila.zimrescueputs:faila.dibrescueputs:faila.gazrescueputs:fail测试输出failbarfailfailfail.发送测试[:foo,:bar,:zim,:dib,:gaz].each{|m|a.send(m)resc

  2. python - 如何使用 Ruby 或 Python 创建一系列高音调和低音调的蜂鸣声? - 2

    关闭。这个问题是opinion-based.它目前不接受答案。想要改进这个问题?更新问题,以便editingthispost可以用事实和引用来回答它.关闭4年前。Improvethisquestion我想在固定时间创建一系列低音和高音调的哔哔声。例如:在150毫秒时发出高音调的蜂鸣声在151毫秒时发出低音调的蜂鸣声200毫秒时发出低音调的蜂鸣声250毫秒的高音调蜂鸣声有没有办法在Ruby或Python中做到这一点?我真的不在乎输出编码是什么(.wav、.mp3、.ogg等等),但我确实想创建一个输出文件。

  3. ruby-on-rails - Rails - 子类化模型的设计模式是什么? - 2

    我有一个模型:classItem项目有一个属性“商店”基于存储的值,我希望Item对象对特定方法具有不同的行为。Rails中是否有针对此的通用设计模式?如果方法中没有大的if-else语句,这是如何干净利落地完成的? 最佳答案 通常通过Single-TableInheritance. 关于ruby-on-rails-Rails-子类化模型的设计模式是什么?,我们在StackOverflow上找到一个类似的问题: https://stackoverflow.co

  4. ruby - 在 64 位 Snow Leopard 上使用 rvm、postgres 9.0、ruby 1.9.2-p136 安装 pg gem 时出现问题 - 2

    我想为Heroku构建一个Rails3应用程序。他们使用Postgres作为他们的数据库,所以我通过MacPorts安装了postgres9.0。现在我需要一个postgresgem并且共识是出于性能原因你想要pggem。但是我对我得到的错误感到非常困惑当我尝试在rvm下通过geminstall安装pg时。我已经非常明确地指定了所有postgres目录的位置可以找到但仍然无法完成安装:$envARCHFLAGS='-archx86_64'geminstallpg--\--with-pg-config=/opt/local/var/db/postgresql90/defaultdb/po

  5. ruby - 什么是填充的 Base64 编码字符串以及如何在 ruby​​ 中生成它们? - 2

    我正在使用的第三方API的文档状态:"[O]urAPIonlyacceptspaddedBase64encodedstrings."什么是“填充的Base64编码字符串”以及如何在Ruby中生成它们。下面的代码是我第一次尝试创建转换为Base64的JSON格式数据。xa=Base64.encode64(a.to_json) 最佳答案 他们说的padding其实就是Base64本身的一部分。它是末尾的“=”和“==”。Base64将3个字节的数据包编码为4个编码字符。所以如果你的输入数据有长度n和n%3=1=>"=="末尾用于填充n%

  6. ruby - 解析 RDFa、微数据等的最佳方式是什么,使用统一的模式/词汇(例如 schema.org)存储和显示信息 - 2

    我主要使用Ruby来执行此操作,但到目前为止我的攻击计划如下:使用gemsrdf、rdf-rdfa和rdf-microdata或mida来解析给定任何URI的数据。我认为最好映射到像schema.org这样的统一模式,例如使用这个yaml文件,它试图描述数据词汇表和opengraph到schema.org之间的转换:#SchemaXtoschema.orgconversion#data-vocabularyDV:name:namestreet-address:streetAddressregion:addressRegionlocality:addressLocalityphoto:i

  7. ruby - 通过 rvm 升级 ruby​​gems 的问题 - 2

    尝试通过RVM将RubyGems升级到版本1.8.10并出现此错误:$rvmrubygemslatestRemovingoldRubygemsfiles...Installingrubygems-1.8.10forruby-1.9.2-p180...ERROR:Errorrunning'GEM_PATH="/Users/foo/.rvm/gems/ruby-1.9.2-p180:/Users/foo/.rvm/gems/ruby-1.9.2-p180@global:/Users/foo/.rvm/gems/ruby-1.9.2-p180:/Users/foo/.rvm/gems/rub

  8. ruby - 为什么 4.1%2 使用 Ruby 返回 0.0999999999999996?但是 4.2%2==0.2 - 2

    为什么4.1%2返回0.0999999999999996?但是4.2%2==0.2。 最佳答案 参见此处:WhatEveryProgrammerShouldKnowAboutFloating-PointArithmetic实数是无限的。计算机使用的位数有限(今天是32位、64位)。因此计算机进行的浮点运算不能代表所有的实数。0.1是这些数字之一。请注意,这不是与Ruby相关的问题,而是与所有编程语言相关的问题,因为它来自计算机表示实数的方式。 关于ruby-为什么4.1%2使用Ruby返

  9. ruby - ruby 中的 TOPLEVEL_BINDING 是什么? - 2

    它不等于主线程的binding,这个toplevel作用域是什么?此作用域与主线程中的binding有何不同?>ruby-e'putsTOPLEVEL_BINDING===binding'false 最佳答案 事实是,TOPLEVEL_BINDING始终引用Binding的预定义全局实例,而Kernel#binding创建的新实例>Binding每次封装当前执行上下文。在顶层,它们都包含相同的绑定(bind),但它们不是同一个对象,您无法使用==或===测试它们的绑定(bind)相等性。putsTOPLEVEL_BINDINGput

  10. ruby - 通过 RVM (OSX Mountain Lion) 安装 Ruby 2.0.0-p247 时遇到问题 - 2

    我的最终目标是安装当前版本的RubyonRails。我在OSXMountainLion上运行。到目前为止,这是我的过程:已安装的RVM$\curl-Lhttps://get.rvm.io|bash-sstable检查已知(我假设已批准)安装$rvmlistknown我看到当前的稳定版本可用[ruby-]2.0.0[-p247]输入命令安装$rvminstall2.0.0-p247注意:我也试过这些安装命令$rvminstallruby-2.0.0-p247$rvminstallruby=2.0.0-p247我很快就无处可去了。结果:$rvminstall2.0.0-p247Search

随机推荐