咿呀咿呀
2022-04-23 / 0 comments / 0 likes / 197 reads / 18,129 characters

Garbage Classification Demo


If this is for a course project, don't copy it wholesale — it won't pass the plagiarism check, hehe.

Overview

This demo consists of three parts:

  • WeChat mini program: takes a photo, uploads it, and displays the result

  • Recognition backend: a REST API built with Flask

  • Web visualization dashboard
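To make the interface between the three parts concrete, here is a minimal sketch of the upload contract between the mini program and the backend (the `/predict` route and the `file` field come from the code later in the post; the fixed label string is a hypothetical placeholder for the real classifier):

```python
import io
import json
from flask import Flask, request

app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def predict():
    # The mini program uploads the photo as multipart form field "file"
    img_bytes = request.files['file'].read()
    # The real backend runs the classifier here; this stub returns a
    # fixed label (hypothetical placeholder value)
    label = "可回收物_易拉罐"
    return json.dumps(label, ensure_ascii=False)

# Exercise the contract with Flask's built-in test client
client = app.test_client()
resp = client.post('/predict',
                   data={'file': (io.BytesIO(b'fake image bytes'), 'photo.jpg')},
                   content_type='multipart/form-data')
print(resp.get_data(as_text=True))  # "可回收物_易拉罐"
```

The mini program's `wx.uploadFile` call plays the role of the test client here: same field name, same route, JSON string back.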

WeChat Mini Program

Features

(feature overview diagram)

Results

Login page

(screenshot)

Home page

(screenshot)

Tapping "Recognize"

(screenshot)

Successful recognition

(screenshot)

Core code

Mini program home page logic

//index.js
// Get the app instance
const app = getApp()

Page({
  data: {
    tempFilePaths: '',
    hidden: true,
    nocancel: true,
    image_url: '',
    result: ''
  },
  confirm: function () {
    this.setData({
      hidden: !this.data.hidden,
      result: ''
    });
  },
  // Let the user pick a photo from the album or the camera
  chooseimage: function () {
    var that = this;
    wx.showActionSheet({
      itemList: ['从相册选择', '拍照'],
      itemColor: "#9978c4",
      success: function (res) {
        if (!res.cancel) {
          if (res.tapIndex == 0) {
            that.chooseWxImage('album')
          } else if (res.tapIndex == 1) {
            that.chooseWxImage('camera')
          }
        }
      }
    })
  },
  chooseWxImage: function (type) {
    var that = this;
    wx.chooseImage({
      sizeType: ['original', 'compressed'],
      sourceType: [type],
      success: function (res) {
        console.log(res);
        that.uploadImg(res.tempFilePaths[0])
        that.setData({
          tempFilePaths: res.tempFilePaths[0],
          hidden: false
        })
      },
    })
  },
  // Upload the image to the Flask backend and show the predicted label
  uploadImg: function (filePath) {
    var that = this;
    wx.uploadFile({
      url: 'http://localhost:5001/predict',
      filePath: filePath,
      name: 'file',
      success: function (res) {
        that.setData({
          // Decode "\uXXXX" escape sequences in the response body
          result: unescape(res.data.replace(/\\/g, "%"))
        });
        wx.showModal({
          title: '识别结果',
          confirmText: "识别正确",
          cancelText: "识别错误",
          content: res.data,
          success: function (res) {
            if (res.confirm) {
              console.log('识别正确')
            } else if (res.cancel) {
              console.log('识别错误')
            }
          }
        })
      }
    })
  }
})

Recognition Backend

Design flow

(design flow diagram)

Dataset

The dataset comes from the Baidu AI Studio platform: 56,528 images covering 214 kinds of objects. All garbage is grouped into four major categories — other waste, hazardous waste, kitchen waste, and recyclables. 90% of the samples are used for training and 10% for validation.

Transfer learning

A ResNet50 model pretrained on the ImageNet dataset is used.

Load the pretrained ResNet50 and replace its final fully connected layer so that its output dimension matches the 214 classes of the garbage image dataset. After training for 45 epochs on a GPU cloud platform, the validation Top-1 accuracy reaches 83.227% and the Top-5 accuracy 95.064%, which meets the accuracy requirements of this application.

Deployment

Flask is used to serve the PyTorch model and expose its classification function through a REST API. The screenshot below shows a local test deployment (a real-world deployment would additionally require an HTTPS certificate and a registered domain):

(screenshot)

Verifying the deployment

Use the curl file-transfer tool to send a test image to the classification API and check the response (e.g. `curl -X POST -F file=@dc.jpg http://localhost:5001/predict`):

(screenshot)

Core code

main.py

import os
import shutil
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from PIL import Image, ImageFile
from tensorboardX import SummaryWriter

ImageFile.LOAD_TRUNCATED_IMAGES = True
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Dataset: each line of the txt file is "image_path<TAB>label"
class Garbage_Loader(Dataset):
    def __init__(self, txt_path, train_flag=True):
        self.imgs_info = self.get_images(txt_path)
        self.train_flag = train_flag
        
        self.train_tf = transforms.Compose([
                transforms.Resize(224),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.ToTensor(),

            ])
        self.val_tf = transforms.Compose([
                transforms.Resize(224),
                transforms.ToTensor(),
            ])
        
    def get_images(self, txt_path):
        with open(txt_path, 'r', encoding='utf-8') as f:
            imgs_info = f.readlines()
            imgs_info = list(map(lambda x:x.strip().split('\t'), imgs_info))
        return imgs_info
     
    def padding_black(self, img):
        # Letterbox: scale the longest side to 224 and paste the image
        # centered on a black 224x224 square, preserving aspect ratio
        w, h  = img.size

        scale = 224. / max(w, h)
        img_fg = img.resize([int(x) for x in [w * scale, h * scale]])

        size_fg = img_fg.size
        size_bg = 224

        img_bg = Image.new("RGB", (size_bg, size_bg))

        img_bg.paste(img_fg, ((size_bg - size_fg[0]) // 2,
                              (size_bg - size_fg[1]) // 2))

        img = img_bg
        return img
        
    def __getitem__(self, index):
        img_path, label = self.imgs_info[index]
        img = Image.open(img_path)
        img = img.convert('RGB')
        img = self.padding_black(img)
        if self.train_flag:
            img = self.train_tf(img)
        else:
            img = self.val_tf(img)
        label = int(label)

        return img, label
 
    def __len__(self):
        return len(self.imgs_info)

#------------------------------------------------------

def accuracy(output, target, topk=(1,)):
    """
        计算topk的准确率
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        class_to = pred[0].cpu().numpy()

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res, class_to

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """
        根据 is_best 存模型,一般保存 valid acc 最好的模型
    """
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best_' + filename)

def train(train_loader, model, criterion, optimizer, epoch, writer):
    """
        训练代码
        参数:
            train_loader - 训练集的 DataLoader
            model - 模型
            criterion - 损失函数
            optimizer - 优化器
            epoch - 进行第几个 epoch
            writer - 用于写 tensorboardX 
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        input = input.cuda()
        target = target.cuda()

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # measure accuracy and record loss
        [prec1, prec5], class_to = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 10 == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   epoch, i, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1, top5=top5))
    writer.add_scalar('loss/train_loss', losses.val, global_step=epoch)

def validate(val_loader, model, criterion, epoch, writer, phase="VAL"):
    """
        验证代码
        参数:
            val_loader - 验证集的 DataLoader
            model - 模型
            criterion - 损失函数
            epoch - 进行第几个 epoch
            writer - 用于写 tensorboardX 
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (input, target) in enumerate(val_loader):
            input = input.cuda()
            target = target.cuda()
            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            [prec1, prec5], class_to = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1[0], input.size(0))
            top5.update(prec5[0], input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 10 == 0:
                print('Test-{0}: [{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                              phase, i, len(val_loader),
                              batch_time=batch_time,
                              loss=losses,
                              top1=top1, top5=top5))

        print(' * {} Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
              .format(phase, top1=top1, top5=top5))
    writer.add_scalar('loss/valid_loss', losses.val, global_step=epoch)
    return top1.avg, top5.avg

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count



#-------------------------------------------------------------------------------------------
#
#
#
#------------------------------------------------------------------------------------------
if __name__ == "__main__":
    # -------------------------------------------- step 1/4 : 加载数据 ---------------------------
    train_dir_list = "train.txt"  #/openbayes/input/input0/
    valid_dir_list = "val.txt"
    batch_size = 16#64
    epochs = 80
    num_classes = 214
    train_data = Garbage_Loader(train_dir_list, train_flag=True)
    valid_data = Garbage_Loader(valid_dir_list, train_flag=False)
    train_loader = DataLoader(dataset=train_data, num_workers=1, pin_memory=True, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(dataset=valid_data, num_workers=1, pin_memory=True, batch_size=batch_size)#8
    train_data_size = len(train_data)
    print('训练集数量:%d' % train_data_size)
    valid_data_size = len(valid_data)
    print('验证集数量:%d' % valid_data_size)
    # ------------------------------------ step 2/4 : 定义网络 ------------------------------------
    model = models.resnet50(pretrained=True)
    fc_inputs = model.fc.in_features
    model.fc = nn.Linear(fc_inputs, num_classes)
    model = model.cuda()
    # ------------------------------------ step 3/4 : 定义损失函数和优化器等 -------------------------
    lr_init = 0.0001
    lr_stepsize = 20
    weight_decay = 0.001
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.Adam(model.parameters(), lr=lr_init, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_stepsize, gamma=0.1)
    
    writer = SummaryWriter('/output')
    # ------------------------------------ step 4/4 : 训练 -----------------------------------------
    best_prec1 = 0
    for epoch in range(epochs):
        scheduler.step()
        train(train_loader, model, criterion, optimizer, epoch, writer)
        # 在验证集上测试效果
        valid_prec1, valid_prec5 = validate(valid_loader, model, criterion, epoch, writer, phase="VAL")
        is_best = valid_prec1 > best_prec1
        best_prec1 = max(valid_prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': 'resnet50',
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer' : optimizer.state_dict(),
            }, is_best,
            filename='checkpoint_resnet50.pth.tar')
    writer.close()
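To make the top-k bookkeeping in `accuracy()` above concrete, here is a tiny worked example with a batch of two samples and three classes:

```python
import torch

# Predicted scores for two samples over three classes
output = torch.tensor([[0.1, 0.5, 0.4],    # sample 0: top-2 classes are 1, 2
                       [0.8, 0.1, 0.1]])   # sample 1: top-2 classes are 0, 1
target = torch.tensor([2, 0])              # true classes

maxk = 2
_, pred = output.topk(maxk, 1, True, True)  # indices of the top-2 classes
pred = pred.t()                             # shape (maxk, batch)
correct = pred.eq(target.view(1, -1).expand_as(pred))

batch_size = target.size(0)
# sample 0's top-1 guess (class 1) is wrong, sample 1's (class 0) is right
top1 = correct[:1].reshape(-1).float().sum() * 100.0 / batch_size
# within the top 2, both true classes appear
top2 = correct[:2].reshape(-1).float().sum() * 100.0 / batch_size
print(top1.item(), top2.item())  # 50.0 100.0
```

This is exactly what `accuracy(output, target, topk=(1, 5))` computes per batch during training and validation.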

flask app.py

# coding=utf-8
from flask import Flask, request
from datetime import timedelta
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models
import torch
import json
import numpy as np
import torch.nn as nn
from tr import transform_image

from sql import update_class, update_xiaolie

app = Flask(__name__)


def softmax(x):
    # Subtract the max for numerical stability
    exp_x = np.exp(x - np.max(x))
    return exp_x / np.sum(exp_x, 0)


# Load the class names ("name<TAB>label" per line)
with open('dir_label.txt', 'r', encoding='utf-8') as f:
    labels = f.readlines()
    labels = list(map(lambda x: x.strip().split('\t'), labels))
    print("labels:", labels)


def padding_black(img):
    # Letterbox: scale the longest side to 224 and paste the image
    # centered on a black 224x224 background
    w, h = img.size
    scale = 224. / max(w, h)
    img_fg = img.resize([int(x) for x in [w * scale, h * scale]])
    size_fg = img_fg.size
    size_bg = 224
    img_bg = Image.new("RGB", (size_bg, size_bg))
    img_bg.paste(img_fg, ((size_bg - size_fg[0]) // 2,
                          (size_bg - size_fg[1]) // 2))
    return img_bg


@app.route('/')
def hello_world():
    return 'Hello World!'


# Allowed upload formats
ALLOWED_EXTENSIONS = set(['png', 'jpg', 'JPG', 'PNG', 'bmp'])


def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1] in ALLOWED_EXTENSIONS


# Static file cache expiry
app.send_file_max_age_default = timedelta(seconds=1)


# Build the model once at startup instead of on every request
model = models.resnet50(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, 214)
# Load the trained weights; map to CPU since the checkpoint was saved on GPU
checkpoint = torch.load('model_best_checkpoint_resnet50.pth.tar', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
model.eval()


@app.route('/predict', methods=['POST', 'GET'])
def upload():
    if request.method == 'POST':
        # Read the uploaded image from the "file" form field
        file = request.files['file']
        img_bytes = file.read()
        # Preprocess into a (1, 3, 224, 224) tensor
        image = transform_image(img_bytes)
        print(image.shape)

        with torch.no_grad():
            pred = model(image)
        pred = pred.data.cpu().numpy()[0]

        score = softmax(pred)
        pred_id = np.argmax(score)

        jieguo = labels[pred_id][0]   # predicted label, prefixed by its category
        print('Prediction:', jieguo)
        leibie = jieguo[:4]           # first 4 characters = the major category
        print('Category:', leibie)
        update_xiaolie(jieguo)        # write the item count to the database
        update_class(leibie)          # write the category count to the database

        # Return the prediction to the front end
        return json.dumps(jieguo, ensure_ascii=False)


if __name__ == '__main__':
    # Test with: curl -X POST -F file=@dc.jpg http://localhost:5001/predict
    app.run(debug=False, port=5001)

Writing backend data to the database

import pymysql


# return: connection, cursor
def get_conn():
    # Create the connection
    conn = pymysql.connect(host="**************",
                           port=3306,
                           user="root",
                           password="*******",
                           db="data",
                           charset="utf8")
    # Create the cursor (results are returned as tuples by default)
    cursor = conn.cursor()
    return conn, cursor


def close_conn(conn, cursor):
    cursor.close()
    conn.close()


def query(sql, *args):
    """
    Generic query helper.
    :param sql: SQL statement with %s placeholders
    :param args: parameters
    :return: the result set, as a tuple of tuples ((), (), ...)
    """
    conn, cursor = get_conn()
    cursor.execute(sql, args)
    res = cursor.fetchall()
    close_conn(conn, cursor)
    return res


def update_class(aim):
    """
    Increment the counter for a major category.
    Uses a parameterized query to avoid SQL injection.
    """
    sql1 = "UPDATE data.class_time SET number = number + 1 WHERE name = %s"
    conn, cursor = get_conn()
    try:
        cursor.execute(sql1, (aim,))
        conn.commit()
    except pymysql.MySQLError:
        conn.rollback()
        print("error")
    close_conn(conn, cursor)


def update_xiaolie(aim):
    """
    Increment the counter for a sub-category (item).
    Note: an all-digit table name must be backtick-quoted in MySQL.
    """
    sql1 = "UPDATE data.`20210502` SET number = number + 1 WHERE name = %s"
    conn, cursor = get_conn()
    try:
        cursor.execute(sql1, (aim,))
        conn.commit()
    except pymysql.MySQLError:
        conn.rollback()
        print("error")
    close_conn(conn, cursor)
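The counter updates above pass the label as a query parameter instead of formatting it into the SQL string. The same pattern, demonstrated with Python's built-in sqlite3 standing in for pymysql (sqlite3 uses `?` as its placeholder where pymysql uses `%s`):

```python
import sqlite3

# In-memory stand-in for the MySQL class_time table
conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE class_time (name TEXT, number INT)")
cur.execute("INSERT INTO class_time VALUES ('可回收物', 0)")

# Parameterized UPDATE: the value is bound, never interpolated into the SQL
cur.execute("UPDATE class_time SET number = number + 1 WHERE name = ?",
            ("可回收物",))
conn.commit()

cur.execute("SELECT number FROM class_time WHERE name = ?", ("可回收物",))
print(cur.fetchone()[0])  # 1
```

With string formatting, a label containing a quote character would break the statement (or worse); the bound parameter is always treated as data.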

Web Visualization Dashboard

Results

(dashboard screenshot)

Core code

flask app.py

from flask import Flask
from flask import render_template
from flask import jsonify

import utils

app = Flask(__name__)

@app.route('/time')
def gettime():
    return utils.get_time()

@app.route('/')
def hello_world():
    return 'Hello World!'

@app.route('/index')
def hello3():
    return render_template("main.html")

@app.route('/c1')
def get_c1_data():
    data = utils.get_c1_data()
    return jsonify({"其他垃圾":data[0],"有害垃圾":data[1],"厨余垃圾":data[2],"可回收物":data[3]})

@app.route('/c2')
def get_c2_data():
    res = []
    for tup in utils.get_c2_data():
        res.append({"name":tup[0][5:],"value":int(tup[1])})

    return jsonify({"kws":res})


@app.route("/r1")
def get_r1_data():
    data = utils.get_r1_data()
    class_ = []
    number_ = []
    for k,v in data:
        class_.append(k[5:])
        number_.append(int(v))
    return jsonify({"class_": class_, "number_": number_})

@app.route("/r2")
def get_r2_data():
    data = utils.get_r2_data()
    d = []
    for i in data:
        k = i[0][5:]
        v = i[1]
        d.append({"name": k, "value": v})
    return jsonify({"kws": d})


@app.route("/l2")
def get_l2_data():
    data = utils.get_l2_data()
    class_l2 = []
    number_l2 = []
    for k,v in data:
        class_l2.append(k[5:])
        number_l2.append(int(v))
    return jsonify({"class_l2": class_l2, "number_l2": number_l2})

@app.route("/l1")
def get_l1_data():
    data = utils.get_l1_data()
    class_l1 = []
    number_l1 = []
    d = []
    for k,v in data:
        class_l1.append(k)
        number_l1.append(int(v))
        d.append({"name": k, "value": v})
    return jsonify({"kws": d})



@app.route('/ajax',methods=['get','post'])
def hello4():
    return '500'

if __name__ == '__main__':
    app.run()
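Several of the endpoints above slice `[5:]` off the stored label before sending it to the charts. Assuming labels are stored as a 4-character major category, a separator, then the item name (e.g. `可回收物_易拉罐` — the exact separator is an assumption based on the slicing), the transformation looks like this:

```python
# Rows as returned from MySQL: (name, number) tuples; the values here
# are illustrative placeholders
rows = [("可回收物_易拉罐", "7"), ("厨余垃圾_果皮", "3")]

# [5:] drops the 4-character category plus the separator,
# leaving just the item name for the chart legend
res = [{"name": t[0][5:], "value": int(t[1])} for t in rows]
print(res)  # [{'name': '易拉罐', 'value': 7}, {'name': '果皮', 'value': 3}]
```

This is exactly the payload shape ECharts consumes in the `/c2` and `/r2` handlers.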

utils.py

import time
import pymysql


def get_time():
    # strftime cannot embed Chinese characters directly, so format() fills them in
    time_str = time.strftime("%Y{}%m{}%d{} %X")
    return time_str.format("年", "月", "日")


# return: connection, cursor
def get_conn():
    # Create the connection
    conn = pymysql.connect(host="*******************",
                           port=3306,
                           user="root",
                           password="****************",
                           db="data",
                           charset="utf8")
    # Create the cursor (results are returned as tuples by default)
    cursor = conn.cursor()
    return conn, cursor
    return conn, cursor


def close_conn(conn, cursor):
    cursor.close()
    conn.close()

def query(sql, *args):
    """
    Generic query helper.
    :param sql: SQL statement with %s placeholders
    :param args: parameters
    :return: the result set, as a tuple of tuples ((), (), ...)
    """
    conn, cursor = get_conn()
    cursor.execute(sql,args)
    res = cursor.fetchall()
    close_conn(conn, cursor)
    return res


def get_c1_data():
    """
    :return: the four major-category totals for dashboard div id=c1
    """
    sql1 = "select number from data.class_time where name = '其他垃圾'"
    sql2 = "select number from data.class_time where name = '有害垃圾'"
    sql3 = "select number from data.class_time where name = '厨余垃圾'"
    sql4 = "select number from data.class_time where name = '可回收物'"
    res1 = query(sql1)
    res2 = query(sql2)
    res3 = query(sql3)
    res4 = query(sql4)
    return res1[0] + res2[0] + res3[0] + res4[0]

#-------------------------------------- Top-5 items ----------------------------------------------------------
def get_r1_data():
    sql = "SELECT name, number FROM data.`20210502` ORDER BY number DESC LIMIT 5;"
    return query(sql)

#-------------------------------------- Top-20 items ---------------------------------------------------------
def get_r2_data():
    sql = "SELECT name, number FROM data.`20210502` ORDER BY number DESC LIMIT 20;"
    return query(sql)

#-------------------------------------- 20 random items ------------------------------------------------------
def get_l2_data():
    sql = "SELECT name, number FROM data.`20210502` ORDER BY RAND() LIMIT 20;"
    return query(sql)

#-------------------------------------- Top-15 items (c2) ----------------------------------------------------
def get_c2_data():
    sql = "SELECT name, number FROM data.`20210502` ORDER BY number DESC LIMIT 15;"
    return query(sql)


#-------------------------------------- Four major categories ------------------------------------------------
def get_l1_data():
    sql1 = "select name,number from data.class_time where name = '其他垃圾'"
    sql2 = "select name,number from data.class_time where name = '有害垃圾'"
    sql3 = "select name,number from data.class_time where name = '厨余垃圾'"
    sql4 = "select name,number from data.class_time where name = '可回收物'"
    res1 = query(sql1)
    res2 = query(sql2)
    res3 = query(sql3)
    res4 = query(sql4)
    return res1 + res2 + res3 + res4


if __name__ == "__main__":
    print('--------------------------------------')
    print(get_c2_data())

References

[2] 王爽. Research on the application of WeChat mini programs in garbage sorting [J]. 信息与电脑, 2019(22): 66-68.

[3] 黑马程序员. WeChat Mini Program Development in Practice [M]. 人民邮电出版社, 2019: 8.

[4] 肖智清. Neural Networks and PyTorch in Action [M]. 机械工业出版社, 2018: 40.

[7] Miguel Grinberg (translated by 安道). Flask Web Development: Developing Web Applications with Python [M]. 人民邮电出版社, 2014: preface 1.

[8] 王大伟. ECharts Data Visualization: Introduction, Practice, and Advanced Topics [M]. 机械工业出版社, 2020: 11.

[9] 魏喆. Research on image classification algorithms based on convolutional neural networks [D]. 北京工业大学, 2018.

[10] HE K, ZHANG X, REN S, et al. Deep residual learning for image recognition [C]. Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 2016: 770-778.

[11] 郑誉煌, 戴冰燕, 熊泽珲. Research on image classification of recyclable household garbage based on transfer learning [J]. 广东第二师范学院学报, 2020(3): 94.
