"""pytorch调用已训练图片分类模型 — use a trained PyTorch image-classification model to sort images into per-class folders."""
import os
import shutil
import time
import torch
import torch.nn as nn
from torchvision import transforms
import torch.nn.functional as F
from PIL import Image
# Folder layout used by this script:
#   <mubiao_path>/weifen  - images still to be classified (read below)
#   <mubiao_path>/yifen   - output root, one sub-folder per predicted class
mubiao_path = os.path.join('D:\\', 'mubiao')
weifen_path = os.path.join(mubiao_path, 'weifen')
# Built with os.path.join like its siblings. The previous
# "mubiao_path + '\yifen'" only worked because the invalid escape "\y"
# happened to be kept as a literal backslash (SyntaxWarning on Python 3.12+).
yifen_path = os.path.join(mubiao_path, 'yifen')

# Output sub-folder names, indexed by the model's predicted class index
# (0 -> '1', 1 -> '100').
# NOTE(review): confirm this order matches the label order used in training.
class_dirs = ['1', '100']
for class_dir in class_dirs:
    # makedirs also creates yifen_path itself when missing, and exist_ok
    # avoids the check-then-create race of "if not exists: mkdir".
    os.makedirs(os.path.join(yifen_path, class_dir), exist_ok=True)
class Net(nn.Module):
    """Small two-convolution CNN that outputs probabilities for 2 classes.

    Layer sizes assume a 224x224 input, which two 2x2 poolings reduce to a
    64x56x56 feature map. The attribute names (conv1, conv2, pool, fc1, fc2)
    are the keys of the saved state_dict, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 3 -> 32 -> 64 channels, 3x3 kernels, padding 1
        # keeps the spatial size unchanged before each pooling.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # One 2x2 max-pool module, reused after each convolution.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Classifier head over the flattened 64x56x56 feature map.
        self.fc1 = nn.Linear(64 * 56 * 56, 128)
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        """Return a (batch, 2) tensor of softmax probabilities for input ``x``."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        flat = x.view(-1, 64 * 56 * 56)
        hidden = F.relu(self.fc1(flat))
        logits = self.fc2(hidden)
        # The network itself normalizes: callers receive probabilities, not logits.
        return F.softmax(logits, dim=1)
# Run on CUDA when a GPU is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the network on the chosen device, then load the trained weights.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
model = Net().to(device)
model.load_state_dict(
    torch.load('fenlei_rmb_1_10.pth', map_location=device)
)
# Evaluation mode before inference. Net has no dropout/batch-norm layers,
# so this is the conventional step rather than a numeric change here.
model.eval()

# Preprocessing: resize to the fixed 224x224 size Net's fc1 expects,
# convert to a float tensor, then compute (x - 0.5) / 0.5. The single-element
# mean/std is broadcast across every channel.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)
# Minimum probability of the predicted class for an image to be moved;
# less confident images stay in weifen_path.
CONFIDENCE_THRESHOLD = 0.7
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp', '.gif')

# Wall-clock timer for the printed total: process_time() counted only this
# process's CPU time and missed time spent waiting on disk and GPU.
date_start = time.perf_counter()
for filename in os.listdir(weifen_path):
    filepath = os.path.join(weifen_path, filename)
    # Case-insensitive match so e.g. 'IMG_001.JPG' is not skipped; also skip
    # sub-directories whose names happen to end in an image extension.
    if not filename.lower().endswith(IMAGE_EXTENSIONS):
        continue
    if not os.path.isfile(filepath):
        continue
    try:
        # The with-block closes the file before shutil.move below; an open
        # handle (e.g. multi-frame GIFs) prevents moving it on Windows.
        # convert('RGB') gives RGBA/palette/grayscale images the 3 channels
        # that conv1 expects.
        with Image.open(filepath) as image:
            image_tensor = transform(image.convert('RGB')).unsqueeze(0).to(device)
    except OSError as err:
        # Unreadable/corrupt file (PIL's UnidentifiedImageError is an OSError):
        # report it and keep processing the remaining images.
        print(f'skip {filepath}: {err}')
        continue
    with torch.no_grad():
        # Net.forward already ends in softmax, so these are probabilities.
        # Applying softmax again would squash 2-class confidences into about
        # [0.27, 0.73], making the 0.7 threshold demand ~0.92 real confidence.
        probabilities = model(image_tensor)[0]
    confidence, predicted = torch.max(probabilities, dim=0)
    if confidence.item() > CONFIDENCE_THRESHOLD:
        class_dir = class_dirs[predicted.item()]
        output_path = os.path.join(yifen_path, class_dir, filename)
        shutil.move(filepath, output_path)
end_date = time.perf_counter()
print(f'总用时:{end_date - date_start}秒')