草庐IT

FCOS论文复现:通用物体检测算法

华为云开发者社区 2023-03-28 原文
摘要:本案例代码是FCOS论文复现的体验案例,此模型为FCOS论文中所提出算法在ModelArts + PyTorch框架下的实现。本代码支持FCOS + ResNet-101在MS-COCO数据集上完整的训练和测试流程

本文分享自华为云社区《通用物体检测算法 FCOS(目标检测/Pytorch)》,作者: HWCloudAI 。

FCOS:Fully Convolutional One-Stage Object Detection

本案例代码是FCOS论文复现的体验案例

此模型为FCOS论文中所提出算法在ModelArts + PyTorch框架下的实现。该算法使用MS-COCO公共数据集进行训练和评估。本代码支持FCOS + ResNet-101在MS-COCO数据集上完整的训练和测试流程

具体的算法介绍:https://marketplace.huaweicloud.com/markets/aihub/modelhub/detail/?id=ce7acc40-0540-45c9-a0c6-e2fda8d1ac7e

注意事项:

1.本案例使用框架: PyTorch1.0.0

2.本案例使用硬件: GPU

3.运行代码方法: 点击本页面顶部菜单栏的三角形运行按钮或按Ctrl+Enter键 运行每个方块中的代码

1.数据和代码下载

import os
import moxing as mox
# 数据代码下载
mox.file.copy_parallel('obs://obs-aigallery-zc/algorithm/FCOS.zip','FCOS.zip')
# 解压缩
os.system('unzip  FCOS.zip -d ./')

2.模型训练

2.1依赖库安装及加载

"""
Basic training script for PyTorch
"""
# Set up custom environment before nearly anything else is imported
# NOTE: this should be the first import (no not reorder)
import os
import argparse
import torch
import shutil
src_dir = './FCOS/'
os.chdir(src_dir)
os.system('pip install -r ./pip-requirements.txt')
os.system('python -m pip install ./trained_model/model/framework-2.0-cp36-cp36m-linux_x86_64.whl')
os.system('python setup.py build develop')
from framework.utils.env import setup_environment
from framework.config import cfg
from framework.data import make_data_loader
from framework.solver import make_lr_scheduler
from framework.solver import make_optimizer
from framework.engine.inference import inference
from framework.engine.trainer import do_train
from framework.modeling.detector import build_detection_model
from framework.utils.checkpoint import DetectronCheckpointer
from framework.utils.collect_env import collect_env_info
from framework.utils.comm import synchronize, \
 get_rank, is_pytorch_1_1_0_or_later
from framework.utils.logger import setup_logger
from framework.utils.miscellaneous import mkdir

2.2训练函数

def train(cfg, local_rank, distributed, new_iteration=False):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
 if cfg.MODEL.USE_SYNCBN:
 assert is_pytorch_1_1_0_or_later(), \
 "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)
 if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
 # this should be removed if we update BatchNorm stats
 broadcast_buffers=False,
 )
    arguments = {}
    arguments["iteration"] = 0
 output_dir = cfg.OUTPUT_DIR
 save_to_disk = get_rank() == 0
 checkpointer = DetectronCheckpointer(
 cfg, model, optimizer, scheduler, output_dir, save_to_disk
 )
 print(cfg.MODEL.WEIGHT)
 extra_checkpoint_data = checkpointer.load_from_file(cfg.MODEL.WEIGHT)
 print(extra_checkpoint_data)
 arguments.update(extra_checkpoint_data)
 if new_iteration:
        arguments["iteration"] = 0
 data_loader = make_data_loader(
 cfg,
 is_train=True,
 is_distributed=distributed,
 start_iter=arguments["iteration"],
 )
 do_train(
        model,
 data_loader,
        optimizer,
        scheduler,
 checkpointer,
        device,
        arguments,
 )
 return model

2.3设置参数,开始训练

def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
 parser.add_argument(
 '--train_url',
        default='./outputs',
 type=str,
 help='the path to save training outputs'
 )
 parser.add_argument(
 "--config-file",
        default="./trained_model/model/fcos_resnet_101_fpn_2x.yaml",
 metavar="FILE",
 help="path to config file",
 type=str,
 )
 parser.add_argument("--local_rank", type=int, default=0)
 parser.add_argument('--train_iterations', default=0, type=int)
 parser.add_argument('--warmup_iterations', default=500, type=int)
 parser.add_argument('--train_batch_size', default=8, type=int)
 parser.add_argument('--solver_lr', default=0.01, type=float)
 parser.add_argument('--decay_steps', default='120000,160000', type=str)
 parser.add_argument('--new_iteration',default=False, action='store_true')
 args, unknown = parser.parse_known_args()
 cfg.merge_from_file(args.config_file)
 # load the model trained on MS-COCO
 if args.train_iterations > 0:
 cfg.SOLVER.MAX_ITER = args.train_iterations
 if args.warmup_iterations > 0:
 cfg.SOLVER.WARMUP_ITERS = args.warmup_iterations
 if args.train_batch_size > 0:
 cfg.SOLVER.IMS_PER_BATCH = args.train_batch_size
 if args.solver_lr > 0:
 cfg.SOLVER.BASE_LR = args.solver_lr
 if len(args.decay_steps) > 0:
        steps = args.decay_steps.replace(' ', ',')
        steps = steps.replace(';', ',')
        steps = steps.replace('', ',')
        steps = steps.replace('', ',')
        steps = steps.split(',')
        steps = tuple([int(x) for x in steps])
 cfg.SOLVER.STEPS = steps
 cfg.freeze()
 num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
 args.distributed = num_gpus > 1
 if args.distributed:
 torch.cuda.set_device(args.local_rank)
 torch.distributed.init_process_group(
            backend="nccl", init_method="env://"
 )
 synchronize()
 output_dir = args.train_url
 if output_dir:
 mkdir(output_dir)
    logger = setup_logger("framework", output_dir, get_rank())
 logger.info("Using {} GPUs".format(num_gpus))
    logger.info(args)
 logger.info("Loaded configuration file {}".format(args.config_file))
 train(cfg, args.local_rank, args.distributed, args.new_iteration)
if __name__ == "__main__":
 main()

3.模型测试

3.1预测函数

from framework.engine.predictor import Predictor
from PIL import Image,ImageDraw
import numpy as np
import matplotlib.pyplot as plt
def predict(img_path,model_path): 
 config_file = "./trained_model/model/fcos_resnet_101_fpn_2x.yaml"
 cfg.merge_from_file(config_file)
 cfg.defrost()
 cfg.MODEL.WEIGHT = model_path
 cfg.OUTPUT_DIR = None
 cfg.freeze()
    predictor = Predictor(cfg=cfg, min_image_size=800)
 src_img = Image.open(img_path)
 img = src_img.convert('RGB')
 img = np.array(img)
 img = img[:, :, ::-1]
    predictions = predictor.compute_prediction(img)
 top_predictions = predictor.select_top_predictions(predictions)
 bboxes = top_predictions.bbox.int().numpy().tolist()
 bboxes = [[x[1], x[0], x[3], x[2]] for x in bboxes]
    scores = top_predictions.get_field("scores").numpy().tolist()
    scores = [round(x, 4) for x in scores]
    labels = top_predictions.get_field("labels").numpy().tolist()
    labels = [predictor.CATEGORIES[x] for x in labels]
    draw = ImageDraw.Draw(src_img)
 for i,bbox in enumerate(bboxes):
 draw.text((bbox[1],bbox[0]),labels[i] + ':'+str(scores[i]),fill=(255,0,0))
 draw.rectangle([bbox[1],bbox[0],bbox[3],bbox[2]],fill=None,outline=(255,0,0))
 return src_img

3.2开始预测

if __name__ == "__main__":
 model_path = "./outputs/weights/fcos_resnet_101_fpn_2x/model_final.pth" # 训练得到的模型
 image_path = "./trained_model/model/demo_image.jpg" # 预测的图像
 img = predict(image_path,model_path)
 plt.figure(figsize=(10,10)) #设置窗口大小
 plt.imshow(img)
 plt.show()
2021-06-09 15:33:15,362 framework.utils.checkpoint INFO: Loading checkpoint from ./outputs/weights/fcos_resnet_101_fpn_2x/model_final.pth

 

点击关注,第一时间了解华为云新鲜技术~

有关FCOS论文复现:通用物体检测算法的更多相关文章

  1. ruby - RuntimeError(自动加载常量 Apps 多线程时检测到循环依赖 - 2

    我收到这个错误:RuntimeError(自动加载常量Apps时检测到循环依赖当我使用多线程时。下面是我的代码。为什么会这样?我尝试多线程的原因是因为我正在编写一个HTML抓取应用程序。对Nokogiri::HTML(open())的调用是一个同步阻塞调用,需要1秒才能返回,我有100,000多个页面要访问,所以我试图运行多个线程来解决这个问题。有更好的方法吗?classToolsController0)app.website=array.join(',')putsapp.websiteelseapp.website="NONE"endapp.saveapps=Apps.order("

  2. 区块链之加解密算法&数字证书 - 2

    目录一.加解密算法数字签名对称加密DES(DataEncryptionStandard)3DES(TripleDES)AES(AdvancedEncryptionStandard)RSA加密法DSA(DigitalSignatureAlgorithm)ECC(EllipticCurvesCryptography)非对称加密签名与加密过程非对称加密的应用对称加密与非对称加密的结合二.数字证书图解一.加解密算法加密简单而言就是通过一种算法将明文信息转换成密文信息,信息的的接收方能够通过密钥对密文信息进行解密获得明文信息的过程。根据加解密的密钥是否相同,算法可以分为对称加密、非对称加密、对称加密和非

  3. [Vuforia]二.3D物体识别 - 2

    之前说过10之后的版本没有3dScan了,所以还是9.8的版本或者之前更早的版本。 3d物体扫描需要先下载扫描的APK进行扫面。首先要在手机上装一个扫描程序,扫描现实中的三维物体,然后上传高通官网,在下载成UnityPackage类型让Unity能够使用这个扫描程序可以从高通官网上进行下载,是一个安卓程序。点到Tools往下滑,找到VuforiaObjectScanner下载后解压数据线连接手机,将apk文件拷入手机安装然后刚才解压文件中的Media文件夹打开,两个PDF图打印第一张A4-ObjectScanningTarget.pdf,主要是用来辅助扫描的。好了,接下来就是扫描三维物体。将瓶

  4. ruby - 检测由 RSpec、Ruby 运行的代码 - 2

    我想知道我的代码是否在rspec下运行。这可能吗?原因是我正在加载一些错误记录器,这些记录器在测试期间会被故意错误(expect{x}.toraise_error)弄得乱七八糟。我查看了我的ENV变量,没有(明显的)测试环境变量的迹象。 最佳答案 在spec_helper.rb的开头添加:ENV['RACK_ENV']='test'现在您可以在代码中检查RACK_ENV是否经过测试。 关于ruby-检测由RSpec、Ruby运行的代码,我们在StackOverflow上找到一个类似的问题

  5. ruby - 使用 Ruby Daemons gem 检测停止 - 2

    我正在使用rubydaemongem。想知道如何向停止操作添加一些额外的步骤?希望我能检测到停止被调用,并向其添加一些额外的代码。任何人都知道我如何才能做到这一点? 最佳答案 查看守护程序gem代码,它似乎没有用于此目的的明显扩展点。但是,我想知道(在守护进程中)您是否可以捕获守护进程在发生“停止”时发送的KILL/TERM信号...?trap("TERM")do#executeyourextracodehereend或者你可以安装一个at_exit钩子(Hook):-at_exitdo#executeyourextracodehe

  6. ruby - Ruby 脚本如何检测到它正在 irb 中运行? - 2

    我有一个定义类的Ruby脚本。我希望脚本执行语句BoolParser.generate:file_base=>'bool_parser'仅当脚本作为可执行文件被调用时,而不是当它被irbrequire(或通过-r在命令行上传递)时。我可以用什么来包装上面的语句,以防止它在我的Ruby文件加载时执行? 最佳答案 条件$0==__FILE__...!/usr/bin/ruby1.8classBoolParserdefself.generate(args)p['BoolParser.generate',args]endendif$0==_

  7. Ruby 无法检测字符串中的换行符 - 2

    我有以下字符串,我想检测那里的换行符。但是Ruby的字符串方法include?检测不到它。我正在运行Ruby1.9.2p290。我哪里出错了?"/'ædres/\nYour".include?('\n')=>false 最佳答案 \n需要在双引号内,否则无法转义。>>"\n".include?'\n'=>false>>"\n".include?"\n"=>true 关于Ruby无法检测字符串中的换行符,我们在StackOverflow上找到一个类似的问题: h

  8. 100个python算法超详细讲解:画直线 - 2

    1.问题描述使用Python的turtle(海龟绘图)模块提供的函数绘制直线。2.问题分析一幅复杂的图形通常都可以由点、直线、三角形、矩形、平行四边形、圆、椭圆和圆弧等基本图形组成。其中的三角形、矩形、平行四边形又可以由直线组成,而直线又是由两个点确定的。我们使用Python的turtle模块所提供的函数来绘制直线。在使用之前我们先介绍一下turtle模块的相关知识点。turtle模块提供面向对象和面向过程两种形式的海龟绘图基本组件。面向对象的接口类如下:1)TurtleScreen类:定义图形窗口作为绘图海龟的运动场。它的构造器需要一个tkinter.Canvas或ScrolledCanva

  9. 【自动驾驶环境感知项目】——基于Paddle3D的点云障碍物检测 - 2

    文章目录1.自动驾驶实战:基于Paddle3D的点云障碍物检测1.1环境信息1.2准备点云数据1.3安装Paddle3D1.4模型训练1.5模型评估1.6模型导出1.7模型部署效果附录show_lidar_pred_on_image.py1.自动驾驶实战:基于Paddle3D的点云障碍物检测项目地址——自动驾驶实战:基于Paddle3D的点云障碍物检测课程地址——自动驾驶感知系统揭秘1.1环境信息硬件信息CPU:2核AI加速卡:v100总显存:16GB总内存:16GB总硬盘:100GB环境配置Python:3.7.4框架信息框架版本:PaddlePaddle2.4.0(项目默认框架版本为2.3

  10. ruby - 在 Ruby 中实现 Luhn 算法 - 2

    我一直在尝试用Ruby实现Luhn算法。我一直在执行以下步骤:该公式根据其包含的校验位验证数字,该校验位通常附加到部分帐号以生成完整帐号。此帐号必须通过以下测试:从最右边的校验位开始向左移动,每第二个数字的值加倍。将乘积的数字(例如,10=1+0=1、14=1+4=5)与原始数字的未加倍数字相加。如果总模10等于0(如果总和以零结尾),则根据Luhn公式该数字有效;否则无效。http://en.wikipedia.org/wiki/Luhn_algorithm这是我想出的:defvalidCreditCard(cardNumber)sum=0nums=cardNumber.to_s.s

随机推荐