diff --git a/README.md b/README.md index 3ea98288..f7a99433 100644 --- a/README.md +++ b/README.md @@ -64,9 +64,31 @@ ATTENTION: The pip and conda packages of PyTorch3D have different dependencies, Run this demo with specified FFHQ image name and computing device, ``` -python photometric_fitting.py 00000 cuda +python demos/photometric_fitting.py 00000 cuda ``` +Run custom image, +``` +python demos/wj_fitting.py FFHQ/00000.png cuda +``` + +Run reconstruct face and driving expression, +``` +python demos/exp_with_texture.py video.mp4 cuda +``` + +Run transfer expression, +``` +python demos/transfer_exp.py video.mp4 basic_model.npy cuda +``` + +facial landmark [face-alignment](https://github.com/1adrianb/face-alignment) + +face segmentation [face-parsing.PyTorch](https://github.com/zllrunning/face-parsing.PyTorch) + +related model can be found [Google Cloud](https://drive.google.com/file/d/1_vyhJUiy3y9DyrtPRkqBq00nkpn766c-/view?usp=sharing) or [Baidu Yun](https://pan.baidu.com/s/1S-CYb3KFk2CI02HU_3jUKA) code:1emq + + Another simple demo to sample the texture space can be found [here](https://github.com/TimoBolkart/TF_FLAME). 
diff --git a/demos/__init__.py b/demos/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/demos/exp_with_texture.py b/demos/exp_with_texture.py new file mode 100644 index 00000000..0d53f18d --- /dev/null +++ b/demos/exp_with_texture.py @@ -0,0 +1,170 @@ +import os, sys +import cv2 +import torch +import torchvision +import torch.nn.functional as F +import torch.nn as nn +import numpy as np +import datetime +sys.path.append(".") +from utils.renderer import Renderer +from utils import util +from utils.config import cfg +from facial_alignment.detection import sfd_detector as detector +from facial_alignment.detection import FAN_landmark +from models.face_seg_model import BiSeNet +from models.FLAME import FLAME, FLAMETex + +torch.backends.cudnn.benchmark = True + + +class PhotometricFitting(object): + def __init__(self, device='cuda'): + self.config = cfg + self.device = device + self.flame = FLAME(self.config).to(self.device) + self.flametex = FLAMETex(self.config).to(self.device) + + self._setup_renderer() + + def _setup_renderer(self): + self.render = Renderer(cfg.image_size, obj_filename=cfg.mesh_file).to(self.device) + + def optimize(self, images, landmarks, image_masks, all_param, video_writer, first_flag): + shape_para, tex_para, exp_para, pose_para, cam_para, lights_para = all_param + e_opt = torch.optim.Adam( + [shape_para, exp_para, pose_para, cam_para, tex_para, lights_para], + lr=cfg.e_lr, + weight_decay=cfg.e_wd + ) + d_opt = torch.optim.Adam( + [shape_para, exp_para, pose_para, cam_para], + lr=cfg.e_lr, + weight_decay=cfg.e_wd + ) + + gt_landmark = landmarks + max_iter = 50 + if first_flag: + max_iter = cfg.max_iter + + tmp_predict = torch.squeeze(images) + for k in range(0, max_iter): + losses = {} + vertices, landmarks2d, landmarks3d = self.flame(shape_params=shape_para, expression_params=exp_para, + pose_params=pose_para) + trans_vertices = util.batch_orth_proj(vertices, cam_para) + trans_vertices[..., 1:] = - trans_vertices[..., 1:] 
+ landmarks2d = util.batch_orth_proj(landmarks2d, cam_para) + landmarks2d[..., 1:] = - landmarks2d[..., 1:] + landmarks3d = util.batch_orth_proj(landmarks3d, cam_para) + landmarks3d[..., 1:] = - landmarks3d[..., 1:] + + losses['landmark'] = util.l2_distance(landmarks2d[:, :, :2], gt_landmark[:, :, :2]) + + # render + albedos = self.flametex(tex_para) / 255. + ops = self.render(vertices, trans_vertices, albedos, lights_para) + tmp_predict = torchvision.utils.make_grid(ops['images'][0].detach().float().cpu()) + # losses['photometric_texture'] = (image_masks * (ops['images'] - images).abs()).mean() * config.w_pho + if first_flag: + losses['photometric_texture'] = F.smooth_l1_loss(image_masks * ops['images'], + image_masks * images) * cfg.w_pho + + all_loss = 0. + for key in losses.keys(): + all_loss = all_loss + losses[key] + losses['all_loss'] = all_loss + if first_flag: + e_opt.zero_grad() + all_loss.backward() + e_opt.step() + else: + d_opt.zero_grad() + all_loss.backward() + d_opt.step() + loss_info = '----iter: {}, time: {}\n'.format(k, datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')) + for key in losses.keys(): + loss_info = loss_info + '{}: {}, '.format(key, float(losses[key])) + print(loss_info) + + # tmp_predict = torchvision.utils.make_grid(ops['images'][0].detach().float().cpu()) + tmp_predict = (tmp_predict.numpy().transpose(1, 2, 0).copy() * 255)[:, :, [2, 1, 0]] + tmp_predict = np.minimum(np.maximum(tmp_predict, 0), 255).astype(np.uint8) + + tmp_image = torchvision.utils.make_grid(images[0].detach().float().cpu()) + tmp_image = (tmp_image.numpy().transpose(1, 2, 0).copy() * 255)[:, :, [2, 1, 0]] + tmp_image = np.minimum(np.maximum(tmp_image, 0), 255).astype(np.uint8) + combine = np.concatenate((tmp_predict, tmp_image), axis=1) + cv2.imshow("tmp_image", combine) + cv2.waitKey(1) + video_writer.write(combine) + return [shape_para, tex_para, exp_para, pose_para, cam_para, lights_para] + + def run(self, img, net, rect_detect, landmark_detect, 
all_param, rect_thresh, out, first_flag): + # The implementation is potentially able to optimize with images(batch_size>1), + # here we show the example with a single image fitting + images = [] + landmarks = [] + image_masks = [] + bbox = rect_detect.extract(img, rect_thresh) + if len(bbox) > 0: + crop_image, new_bbox = util.crop_img(img, bbox[0], cfg.cropped_size) + + resize_img, landmark = landmark_detect.extract([crop_image, [new_bbox]]) + landmark = landmark[0] + landmark[:, 0] = landmark[:, 0] / float(resize_img.shape[1]) * 2 - 1 + landmark[:, 1] = landmark[:, 1] / float(resize_img.shape[0]) * 2 - 1 + landmarks.append(torch.from_numpy(landmark)[None, :, :].double().to(self.device)) + landmarks = torch.cat(landmarks, dim=0) + + image = cv2.resize(crop_image, (cfg.cropped_size, cfg.cropped_size)).astype(np.float32) / 255. + image = image[:, :, ::-1].transpose(2, 0, 1).copy() + images.append(torch.from_numpy(image[None, :, :, :]).double().to(self.device)) + images = torch.cat(images, dim=0) + images = F.interpolate(images, [cfg.image_size, cfg.image_size]) + + image_mask = util.face_seg(crop_image, net, cfg.cropped_size) + image_masks.append(torch.from_numpy(image_mask).double().to(self.device)) + image_masks = torch.cat(image_masks, dim=0) + image_masks = F.interpolate(image_masks, [cfg.image_size, cfg.image_size]) + + single_params = self.optimize(images, landmarks, image_masks, all_param, out, first_flag) + return single_params + + +if __name__ == '__main__': + video_path = str(sys.argv[1]) + device_name = str(sys.argv[2]) + util.check_mkdir(cfg.save_folder) + fitting = PhotometricFitting(device=device_name) + save_video_name = os.path.split(video_path)[1].split(".")[0] + '.avi' + video_writer = cv2.VideoWriter(os.path.join(cfg.save_folder, save_video_name), + cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), 16, + (cfg.image_size * 2, cfg.image_size)) + cap = cv2.VideoCapture(video_path) + ret, frame = cap.read() + + if ret: + w_h_scale = util.resize_para(frame) + 
face_detect = detector.SFDDetector(device_name, cfg.rect_model_path) + face_landmark = FAN_landmark.FANLandmarks(device_name, cfg.landmark_model_path, cfg.face_detect_type) + seg_net = BiSeNet(n_classes=cfg.seg_class) + seg_net.cuda() + seg_net.load_state_dict(torch.load(cfg.face_seg_model)) + seg_net.eval() + first_flag = True + shape = nn.Parameter(torch.zeros(cfg.batch_size, cfg.shape_params).float().to(device_name)) + tex = nn.Parameter(torch.zeros(cfg.batch_size, cfg.tex_params).float().to(device_name)) + exp = nn.Parameter(torch.zeros(cfg.batch_size, cfg.expression_params).float().to(device_name)) + pose = nn.Parameter(torch.zeros(cfg.batch_size, cfg.pose_params).float().to(device_name)) + cam = torch.zeros(cfg.batch_size, cfg.camera_params) + cam[:, 0] = 5. + cam = nn.Parameter(cam.float().to(device_name)) + lights = nn.Parameter(torch.zeros(cfg.batch_size, 9, 3).float().to(device_name)) + all_params = [shape, tex, exp, pose, cam, lights] + while ret: + all_params = fitting.run(frame, seg_net, face_detect, face_landmark, all_params, + cfg.rect_thresh, video_writer, first_flag) + first_flag = False + ret, frame = cap.read() diff --git a/photometric_fitting.py b/demos/photometric_fitting.py similarity index 75% rename from photometric_fitting.py rename to demos/photometric_fitting.py index 4969a22d..27a49696 100644 --- a/photometric_fitting.py +++ b/demos/photometric_fitting.py @@ -1,265 +1,236 @@ -import os, sys -import cv2 -import torch -import torchvision -import torch.nn.functional as F -import torch.nn as nn -import numpy as np -from glob import glob -import time -import datetime -import imageio - -sys.path.append('./models/') -from FLAME import FLAME, FLAMETex -from renderer import Renderer -import util -torch.backends.cudnn.benchmark = True - - -class PhotometricFitting(object): - def __init__(self, config, device='cuda'): - self.batch_size = config.batch_size - self.image_size = config.image_size - self.config = config - self.device = device - # - 
self.flame = FLAME(self.config).to(self.device) - self.flametex = FLAMETex(self.config).to(self.device) - - self._setup_renderer() - - def _setup_renderer(self): - mesh_file = './data/head_template_mesh.obj' - self.render = Renderer(self.image_size, obj_filename=mesh_file).to(self.device) - - def optimize(self, images, landmarks, image_masks, savefolder=None): - bz = images.shape[0] - shape = nn.Parameter(torch.zeros(bz, self.config.shape_params).float().to(self.device)) - tex = nn.Parameter(torch.zeros(bz, self.config.tex_params).float().to(self.device)) - exp = nn.Parameter(torch.zeros(bz, self.config.expression_params).float().to(self.device)) - pose = nn.Parameter(torch.zeros(bz, self.config.pose_params).float().to(self.device)) - cam = torch.zeros(bz, self.config.camera_params); cam[:, 0] = 5. - cam = nn.Parameter(cam.float().to(self.device)) - lights = nn.Parameter(torch.zeros(bz, 9, 3).float().to(self.device)) - e_opt = torch.optim.Adam( - [shape, exp, pose, cam, tex, lights], - lr=self.config.e_lr, - weight_decay=self.config.e_wd - ) - e_opt_rigid = torch.optim.Adam( - [pose, cam], - lr=self.config.e_lr, - weight_decay=self.config.e_wd - ) - - gt_landmark = landmarks - - # rigid fitting of pose and camera with 51 static face landmarks, - # this is due to the non-differentiable attribute of contour landmarks trajectory - for k in range(200): - losses = {} - vertices, landmarks2d, landmarks3d = self.flame(shape_params=shape, expression_params=exp, pose_params=pose) - trans_vertices = util.batch_orth_proj(vertices, cam); - trans_vertices[..., 1:] = - trans_vertices[..., 1:] - landmarks2d = util.batch_orth_proj(landmarks2d, cam); - landmarks2d[..., 1:] = - landmarks2d[..., 1:] - landmarks3d = util.batch_orth_proj(landmarks3d, cam); - landmarks3d[..., 1:] = - landmarks3d[..., 1:] - - losses['landmark'] = util.l2_distance(landmarks2d[:, 17:, :2], gt_landmark[:, 17:, :2]) * config.w_lmks - - all_loss = 0. 
- for key in losses.keys(): - all_loss = all_loss + losses[key] - losses['all_loss'] = all_loss - e_opt_rigid.zero_grad() - all_loss.backward() - e_opt_rigid.step() - - loss_info = '----iter: {}, time: {}\n'.format(k, datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')) - for key in losses.keys(): - loss_info = loss_info + '{}: {}, '.format(key, float(losses[key])) - if k % 10 == 0: - print(loss_info) - - if k % 10 == 0: - grids = {} - visind = range(bz) # [0] - grids['images'] = torchvision.utils.make_grid(images[visind]).detach().cpu() - grids['landmarks_gt'] = torchvision.utils.make_grid( - util.tensor_vis_landmarks(images[visind], landmarks[visind])) - grids['landmarks2d'] = torchvision.utils.make_grid( - util.tensor_vis_landmarks(images[visind], landmarks2d[visind])) - grids['landmarks3d'] = torchvision.utils.make_grid( - util.tensor_vis_landmarks(images[visind], landmarks3d[visind])) - - grid = torch.cat(list(grids.values()), 1) - grid_image = (grid.numpy().transpose(1, 2, 0).copy() * 255)[:, :, [2, 1, 0]] - grid_image = np.minimum(np.maximum(grid_image, 0), 255).astype(np.uint8) - cv2.imwrite('{}/{}.jpg'.format(savefolder, k), grid_image) - - # non-rigid fitting of all the parameters with 68 face landmarks, photometric loss and regularization terms. 
- for k in range(200, 1000): - losses = {} - vertices, landmarks2d, landmarks3d = self.flame(shape_params=shape, expression_params=exp, pose_params=pose) - trans_vertices = util.batch_orth_proj(vertices, cam); - trans_vertices[..., 1:] = - trans_vertices[..., 1:] - landmarks2d = util.batch_orth_proj(landmarks2d, cam); - landmarks2d[..., 1:] = - landmarks2d[..., 1:] - landmarks3d = util.batch_orth_proj(landmarks3d, cam); - landmarks3d[..., 1:] = - landmarks3d[..., 1:] - - losses['landmark'] = util.l2_distance(landmarks2d[:, :, :2], gt_landmark[:, :, :2]) * config.w_lmks - losses['shape_reg'] = (torch.sum(shape ** 2) / 2) * config.w_shape_reg # *1e-4 - losses['expression_reg'] = (torch.sum(exp ** 2) / 2) * config.w_expr_reg # *1e-4 - losses['pose_reg'] = (torch.sum(pose ** 2) / 2) * config.w_pose_reg - - ## render - albedos = self.flametex(tex) / 255. - ops = self.render(vertices, trans_vertices, albedos, lights) - predicted_images = ops['images'] - losses['photometric_texture'] = (image_masks * (ops['images'] - images).abs()).mean() * config.w_pho - - all_loss = 0. 
- for key in losses.keys(): - all_loss = all_loss + losses[key] - losses['all_loss'] = all_loss - e_opt.zero_grad() - all_loss.backward() - e_opt.step() - - loss_info = '----iter: {}, time: {}\n'.format(k, datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')) - for key in losses.keys(): - loss_info = loss_info + '{}: {}, '.format(key, float(losses[key])) - - if k % 10 == 0: - print(loss_info) - - # visualize - if k % 10 == 0: - grids = {} - visind = range(bz) # [0] - grids['images'] = torchvision.utils.make_grid(images[visind]).detach().cpu() - grids['landmarks_gt'] = torchvision.utils.make_grid( - util.tensor_vis_landmarks(images[visind], landmarks[visind])) - grids['landmarks2d'] = torchvision.utils.make_grid( - util.tensor_vis_landmarks(images[visind], landmarks2d[visind])) - grids['landmarks3d'] = torchvision.utils.make_grid( - util.tensor_vis_landmarks(images[visind], landmarks3d[visind])) - grids['albedoimage'] = torchvision.utils.make_grid( - (ops['albedo_images'])[visind].detach().cpu()) - grids['render'] = torchvision.utils.make_grid(predicted_images[visind].detach().float().cpu()) - shape_images = self.render.render_shape(vertices, trans_vertices, images) - grids['shape'] = torchvision.utils.make_grid( - F.interpolate(shape_images[visind], [224, 224])).detach().float().cpu() - - - # grids['tex'] = torchvision.utils.make_grid(F.interpolate(albedos[visind], [224, 224])).detach().cpu() - grid = torch.cat(list(grids.values()), 1) - grid_image = (grid.numpy().transpose(1, 2, 0).copy() * 255)[:, :, [2, 1, 0]] - grid_image = np.minimum(np.maximum(grid_image, 0), 255).astype(np.uint8) - - cv2.imwrite('{}/{}.jpg'.format(savefolder, k), grid_image) - - single_params = { - 'shape': shape.detach().cpu().numpy(), - 'exp': exp.detach().cpu().numpy(), - 'pose': pose.detach().cpu().numpy(), - 'cam': cam.detach().cpu().numpy(), - 'verts': trans_vertices.detach().cpu().numpy(), - 'albedos':albedos.detach().cpu().numpy(), - 'tex': tex.detach().cpu().numpy(), - 'lit': 
lights.detach().cpu().numpy() - } - return single_params - - def run(self, imagepath, landmarkpath): - # The implementation is potentially able to optimize with images(batch_size>1), - # here we show the example with a single image fitting - images = [] - landmarks = [] - image_masks = [] - - image_name = os.path.basename(imagepath)[:-4] - savefile = os.path.sep.join([self.config.savefolder, image_name + '.npy']) - - # photometric optimization is sensitive to the hair or glass occlusions, - # therefore we use a face segmentation network to mask the skin region out. - image_mask_folder = './FFHQ_seg/' - image_mask_path = os.path.sep.join([image_mask_folder, image_name + '.npy']) - - image = cv2.resize(cv2.imread(imagepath), (config.cropped_size, config.cropped_size)).astype(np.float32) / 255. - image = image[:, :, [2, 1, 0]].transpose(2, 0, 1) - images.append(torch.from_numpy(image[None, :, :, :]).to(self.device)) - - image_mask = np.load(image_mask_path, allow_pickle=True) - image_mask = image_mask[..., None].astype('float32') - image_mask = image_mask.transpose(2, 0, 1) - image_mask_bn = np.zeros_like(image_mask) - image_mask_bn[np.where(image_mask != 0)] = 1. 
- image_masks.append(torch.from_numpy(image_mask_bn[None, :, :, :]).to(self.device)) - - landmark = np.load(landmarkpath).astype(np.float32) - landmark[:, 0] = landmark[:, 0] / float(image.shape[2]) * 2 - 1 - landmark[:, 1] = landmark[:, 1] / float(image.shape[1]) * 2 - 1 - landmarks.append(torch.from_numpy(landmark)[None, :, :].float().to(self.device)) - - images = torch.cat(images, dim=0) - images = F.interpolate(images, [self.image_size, self.image_size]) - image_masks = torch.cat(image_masks, dim=0) - image_masks = F.interpolate(image_masks, [self.image_size, self.image_size]) - - landmarks = torch.cat(landmarks, dim=0) - savefolder = os.path.sep.join([self.config.savefolder, image_name]) - - util.check_mkdir(savefolder) - # optimize - single_params = self.optimize(images, landmarks, image_masks, savefolder) - self.render.save_obj(filename=savefile[:-4]+'.obj', - vertices=torch.from_numpy(single_params['verts'][0]).to(self.device), - textures=torch.from_numpy(single_params['albedos'][0]).to(self.device) - ) - np.save(savefile, single_params) - - -if __name__ == '__main__': - image_name = str(sys.argv[1]) - device_name = str(sys.argv[2]) - config = { - # FLAME - 'flame_model_path': './data/generic_model.pkl', # acquire it from FLAME project page - 'flame_lmk_embedding_path': './data/landmark_embedding.npy', - 'tex_space_path': './data/FLAME_texture.npz', # acquire it from FLAME project page - 'camera_params': 3, - 'shape_params': 100, - 'expression_params': 50, - 'pose_params': 6, - 'tex_params': 50, - 'use_face_contour': True, - - 'cropped_size': 256, - 'batch_size': 1, - 'image_size': 224, - 'e_lr': 0.005, - 'e_wd': 0.0001, - 'savefolder': './test_results/', - # weights of losses and reg terms - 'w_pho': 8, - 'w_lmks': 1, - 'w_shape_reg': 1e-4, - 'w_expr_reg': 1e-4, - 'w_pose_reg': 0, - } - - config = util.dict2obj(config) - util.check_mkdir(config.savefolder) - - config.batch_size = 1 - fitting = PhotometricFitting(config, device=device_name) - - input_folder 
= './FFHQ' - - imagepath = os.path.sep.join([input_folder, image_name + '.png']) - landmarkpath = os.path.sep.join([input_folder, image_name + '.npy']) - fitting.run(imagepath, landmarkpath) +import os, sys +import cv2 +import torch +import torchvision +import torch.nn.functional as F +import torch.nn as nn +import numpy as np +import datetime + +sys.path.append('.') +from models.FLAME import FLAME, FLAMETex +from utils.renderer import Renderer +from utils import util +from utils.config import cfg + +torch.backends.cudnn.benchmark = True + + +class PhotometricFitting(object): + def __init__(self, device='cuda'): + self.batch_size = cfg.batch_size + self.image_size = cfg.image_size + self.cropped_size = cfg.cropped_size + self.config = cfg + self.device = device + self.flame = FLAME(self.config).to(self.device) + self.flametex = FLAMETex(self.config).to(self.device) + + self._setup_renderer() + + def _setup_renderer(self): + self.render = Renderer(self.image_size, obj_filename=cfg.mesh_file).to(self.device) + + def optimize(self, images, landmarks, image_masks, savefolder=None): + bz = images.shape[0] + shape = nn.Parameter(torch.zeros(bz, cfg.shape_params).float().to(self.device)) + tex = nn.Parameter(torch.zeros(bz, cfg.tex_params).float().to(self.device)) + exp = nn.Parameter(torch.zeros(bz, cfg.expression_params).float().to(self.device)) + pose = nn.Parameter(torch.zeros(bz, cfg.pose_params).float().to(self.device)) + cam = torch.zeros(bz, cfg.camera_params); cam[:, 0] = 5. 
+ cam = nn.Parameter(cam.float().to(self.device)) + lights = nn.Parameter(torch.zeros(bz, 9, 3).float().to(self.device)) + e_opt = torch.optim.Adam( + [shape, exp, pose, cam, tex, lights], + lr=cfg.e_lr, + weight_decay=cfg.e_wd + ) + e_opt_rigid = torch.optim.Adam( + [pose, cam], + lr=cfg.e_lr, + weight_decay=cfg.e_wd + ) + + gt_landmark = landmarks + + # rigid fitting of pose and camera with 51 static face landmarks, + # this is due to the non-differentiable attribute of contour landmarks trajectory + for k in range(200): + losses = {} + vertices, landmarks2d, landmarks3d = self.flame(shape_params=shape, expression_params=exp, pose_params=pose) + trans_vertices = util.batch_orth_proj(vertices, cam) + trans_vertices[..., 1:] = - trans_vertices[..., 1:] + landmarks2d = util.batch_orth_proj(landmarks2d, cam) + landmarks2d[..., 1:] = - landmarks2d[..., 1:] + landmarks3d = util.batch_orth_proj(landmarks3d, cam) + landmarks3d[..., 1:] = - landmarks3d[..., 1:] + + losses['landmark'] = util.l2_distance(landmarks2d[:, 17:, :2], gt_landmark[:, 17:, :2]) * cfg.w_lmks + + all_loss = 0. 
+ for key in losses.keys(): + all_loss = all_loss + losses[key] + losses['all_loss'] = all_loss + e_opt_rigid.zero_grad() + all_loss.backward() + e_opt_rigid.step() + + loss_info = '----iter: {}, time: {}\n'.format(k, datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')) + for key in losses.keys(): + loss_info = loss_info + '{}: {}, '.format(key, float(losses[key])) + if k % 10 == 0: + print(loss_info) + + if k % 10 == 0: + grids = {} + visind = range(bz) # [0] + grids['images'] = torchvision.utils.make_grid(images[visind]).detach().cpu() + grids['landmarks_gt'] = torchvision.utils.make_grid( + util.tensor_vis_landmarks(images[visind], landmarks[visind])) + grids['landmarks2d'] = torchvision.utils.make_grid( + util.tensor_vis_landmarks(images[visind], landmarks2d[visind])) + grids['landmarks3d'] = torchvision.utils.make_grid( + util.tensor_vis_landmarks(images[visind], landmarks3d[visind])) + + grid = torch.cat(list(grids.values()), 1) + grid_image = (grid.numpy().transpose(1, 2, 0).copy() * 255)[:, :, [2, 1, 0]] + grid_image = np.minimum(np.maximum(grid_image, 0), 255).astype(np.uint8) + cv2.imwrite('{}/{}.jpg'.format(savefolder, k), grid_image) + + # non-rigid fitting of all the parameters with 68 face landmarks, photometric loss and regularization terms. 
+ for k in range(200, 1000): + losses = {} + vertices, landmarks2d, landmarks3d = self.flame(shape_params=shape, expression_params=exp, pose_params=pose) + trans_vertices = util.batch_orth_proj(vertices, cam) + trans_vertices[..., 1:] = - trans_vertices[..., 1:] + landmarks2d = util.batch_orth_proj(landmarks2d, cam) + landmarks2d[..., 1:] = - landmarks2d[..., 1:] + landmarks3d = util.batch_orth_proj(landmarks3d, cam) + landmarks3d[..., 1:] = - landmarks3d[..., 1:] + + losses['landmark'] = util.l2_distance(landmarks2d[:, :, :2], gt_landmark[:, :, :2]) * cfg.w_lmks + losses['shape_reg'] = (torch.sum(shape ** 2) / 2) * cfg.w_shape_reg # *1e-4 + losses['expression_reg'] = (torch.sum(exp ** 2) / 2) * cfg.w_expr_reg # *1e-4 + losses['pose_reg'] = (torch.sum(pose ** 2) / 2) * cfg.w_pose_reg + + ## render + albedos = self.flametex(tex) / 255. + ops = self.render(vertices, trans_vertices, albedos, lights) + predicted_images = ops['images'] + losses['photometric_texture'] = (image_masks * (ops['images'] - images).abs()).mean() * cfg.w_pho + + all_loss = 0. 
+ for key in losses.keys(): + all_loss = all_loss + losses[key] + losses['all_loss'] = all_loss + e_opt.zero_grad() + all_loss.backward() + e_opt.step() + + loss_info = '----iter: {}, time: {}\n'.format(k, datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')) + for key in losses.keys(): + loss_info = loss_info + '{}: {}, '.format(key, float(losses[key])) + + if k % 10 == 0: + print(loss_info) + + # visualize + if k % 10 == 0: + grids = {} + visind = range(bz) # [0] + grids['images'] = torchvision.utils.make_grid(images[visind]).detach().cpu() + grids['landmarks_gt'] = torchvision.utils.make_grid( + util.tensor_vis_landmarks(images[visind], landmarks[visind])) + grids['landmarks2d'] = torchvision.utils.make_grid( + util.tensor_vis_landmarks(images[visind], landmarks2d[visind])) + grids['landmarks3d'] = torchvision.utils.make_grid( + util.tensor_vis_landmarks(images[visind], landmarks3d[visind])) + grids['albedoimage'] = torchvision.utils.make_grid( + (ops['albedo_images'])[visind].detach().cpu()) + grids['render'] = torchvision.utils.make_grid(predicted_images[visind].detach().float().cpu()) + shape_images = self.render.render_shape(vertices, trans_vertices, images) + grids['shape'] = torchvision.utils.make_grid( + F.interpolate(shape_images[visind], [224, 224])).detach().float().cpu() + + + # grids['tex'] = torchvision.utils.make_grid(F.interpolate(albedos[visind], [224, 224])).detach().cpu() + grid = torch.cat(list(grids.values()), 1) + grid_image = (grid.numpy().transpose(1, 2, 0).copy() * 255)[:, :, [2, 1, 0]] + grid_image = np.minimum(np.maximum(grid_image, 0), 255).astype(np.uint8) + + cv2.imwrite('{}/{}.jpg'.format(savefolder, k), grid_image) + + single_params = { + 'shape': shape.detach().cpu().numpy(), + 'exp': exp.detach().cpu().numpy(), + 'pose': pose.detach().cpu().numpy(), + 'cam': cam.detach().cpu().numpy(), + 'verts': trans_vertices.detach().cpu().numpy(), + 'albedos':albedos.detach().cpu().numpy(), + 'tex': tex.detach().cpu().numpy(), + 'lit': 
lights.detach().cpu().numpy() + } + return single_params + + def run(self, imagepath, landmarkpath): + # The implementation is potentially able to optimize with images(batch_size>1), + # here we show the example with a single image fitting + images = [] + landmarks = [] + image_masks = [] + + image_name = os.path.basename(imagepath)[:-4] + savefile = os.path.sep.join([cfg.save_folder, image_name + '.npy']) + + # photometric optimization is sensitive to the hair or glass occlusions, + # therefore we use a face segmentation network to mask the skin region out. + image_mask_folder = './FFHQ_seg/' + image_mask_path = os.path.sep.join([image_mask_folder, image_name + '.npy']) + + image = cv2.resize(cv2.imread(imagepath), (cfg.cropped_size, cfg.cropped_size)).astype(np.float32) / 255. + image = image[:, :, [2, 1, 0]].transpose(2, 0, 1) + images.append(torch.from_numpy(image[None, :, :, :]).to(self.device)) + + image_mask = np.load(image_mask_path, allow_pickle=True) + image_mask = image_mask[..., None].astype('float32') + image_mask = image_mask.transpose(2, 0, 1) + image_mask_bn = np.zeros_like(image_mask) + image_mask_bn[np.where(image_mask != 0)] = 1. 
+ image_masks.append(torch.from_numpy(image_mask_bn[None, :, :, :]).to(self.device)) + + landmark = np.load(landmarkpath).astype(np.float32) + landmark[:, 0] = landmark[:, 0] / float(image.shape[2]) * 2 - 1 + landmark[:, 1] = landmark[:, 1] / float(image.shape[1]) * 2 - 1 + landmarks.append(torch.from_numpy(landmark)[None, :, :].float().to(self.device)) + + images = torch.cat(images, dim=0) + images = F.interpolate(images, [cfg.image_size, cfg.image_size]) + image_masks = torch.cat(image_masks, dim=0) + image_masks = F.interpolate(image_masks, [cfg.image_size, cfg.image_size]) + + landmarks = torch.cat(landmarks, dim=0) + savefolder = os.path.sep.join([cfg.save_folder, image_name]) + + util.check_mkdir(savefolder) + # optimize + single_params = self.optimize(images, landmarks, image_masks, savefolder) + self.render.save_obj(filename=savefile[:-4]+'.obj', + vertices=torch.from_numpy(single_params['verts'][0]).to(self.device), + textures=torch.from_numpy(single_params['albedos'][0]).to(self.device) + ) + np.save(savefile, single_params) + + +if __name__ == '__main__': + image_name = str(sys.argv[1]) + device_name = str(sys.argv[2]) + util.check_mkdir(cfg.save_folder) + + cfg.batch_size = 1 + fitting = PhotometricFitting(device=device_name) + + input_folder = './FFHQ' + + imagepath = os.path.sep.join([input_folder, image_name + '.png']) + landmarkpath = os.path.sep.join([input_folder, image_name + '.npy']) + fitting.run(imagepath, landmarkpath) diff --git a/demos/transfer_exp.py b/demos/transfer_exp.py new file mode 100644 index 00000000..a32cee82 --- /dev/null +++ b/demos/transfer_exp.py @@ -0,0 +1,52 @@ +import os +import sys +import cv2 +import torch +import torch.nn as nn +import numpy as np + +sys.path.append('.') +from models.face_seg_model import BiSeNet +from utils import util +from utils.config import cfg +from facial_alignment.detection import sfd_detector as detector +from facial_alignment.detection import FAN_landmark +from demos.exp_with_texture import 
PhotometricFitting + + +if __name__ == '__main__': + video_path = str(sys.argv[1]) + basic_model = str(sys.argv[2]) + device_name = str(sys.argv[3]) + util.check_mkdir(cfg.save_folder) + fitting = PhotometricFitting(device=device_name) + + # basic face parameter npy file + basic_face_data = np.load(basic_model, allow_pickle=True).item() + shape = nn.Parameter(torch.from_numpy(basic_face_data['shape']).float().to(device_name)) + tex = nn.Parameter(torch.from_numpy(basic_face_data['tex']).float().to(device_name)) + exp = nn.Parameter(torch.from_numpy(basic_face_data['exp']).float().to(device_name)) + pose = nn.Parameter(torch.from_numpy(basic_face_data['pose']).float().to(device_name)) + cam = nn.Parameter(torch.from_numpy(basic_face_data['cam']).float().to(device_name)) + lights = nn.Parameter(torch.from_numpy(basic_face_data['lit']).float().to(device_name)) + all_params = [shape, tex, exp, pose, cam, lights] + + save_video_name = os.path.split(video_path)[1].split(".")[0] + '.avi' + video_writer = cv2.VideoWriter(os.path.join(cfg.save_folder, save_video_name), + cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), 16, + (cfg.image_size * 2, cfg.image_size)) + + cap = cv2.VideoCapture(video_path) + ret, frame = cap.read() + if ret: + w_h_scale = util.resize_para(frame) + face_detect = detector.SFDDetector(device_name, cfg.rect_model_path) + face_landmark = FAN_landmark.FANLandmarks(device_name, cfg.landmark_model_path, cfg.face_detect_type) + seg_net = BiSeNet(n_classes=cfg.seg_class) + seg_net.cuda() + seg_net.load_state_dict(torch.load(cfg.face_seg_model)) + seg_net.eval() + while ret: + all_params = fitting.run(frame, seg_net, face_detect, face_landmark, all_params, + cfg.rect_thresh, video_writer, False) + ret, frame = cap.read() diff --git a/demos/wj_fitting.py b/demos/wj_fitting.py new file mode 100644 index 00000000..5470942e --- /dev/null +++ b/demos/wj_fitting.py @@ -0,0 +1,195 @@ +import os, sys +import cv2 +import torch +import torchvision +import 
torch.nn.functional as F +import torch.nn as nn +import numpy as np +import datetime +sys.path.append('.') +from models.FLAME import FLAME, FLAMETex +from models.face_seg_model import BiSeNet +from utils.renderer import Renderer +from utils import util +from utils.config import cfg +from facial_alignment.detection import sfd_detector as detector +from facial_alignment.detection import FAN_landmark + +torch.backends.cudnn.benchmark = True + + +class PhotometricFitting(object): + def __init__(self, device='cuda'): + # self.batch_size = cfg.batch_size + # self.image_size = cfg.image_size + # self.cropped_size = cfg.cropped_size + self.config = cfg + self.device = device + self.flame = FLAME(self.config).to(self.device) + self.flametex = FLAMETex(self.config).to(self.device) + + self._setup_renderer() + + def _setup_renderer(self): + self.render = Renderer(cfg.image_size, obj_filename=cfg.mesh_file).to(self.device) + + def optimize(self, images, landmarks, image_masks, video_writer): + bz = images.shape[0] + shape = nn.Parameter(torch.zeros(bz, cfg.shape_params).float().to(self.device)) + tex = nn.Parameter(torch.zeros(bz, cfg.tex_params).float().to(self.device)) + exp = nn.Parameter(torch.zeros(bz, cfg.expression_params).float().to(self.device)) + pose = nn.Parameter(torch.zeros(bz, cfg.pose_params).float().to(self.device)) + cam = torch.zeros(bz, cfg.camera_params) + cam[:, 0] = 5. + cam = nn.Parameter(cam.float().to(self.device)) + lights = nn.Parameter(torch.zeros(bz, 9, 3).float().to(self.device)) + e_opt = torch.optim.Adam( + [shape, exp, pose, cam, tex, lights], + lr=cfg.e_lr, + weight_decay=cfg.e_wd + ) + + gt_landmark = landmarks + + # non-rigid fitting of all the parameters with 68 face landmarks, photometric loss and regularization terms. 
+ all_train_iter = 0 + all_train_iters = [] + photometric_loss = [] + for k in range(cfg.max_iter): + losses = {} + vertices, landmarks2d, landmarks3d = self.flame(shape_params=shape, expression_params=exp, pose_params=pose) + trans_vertices = util.batch_orth_proj(vertices, cam) + trans_vertices[..., 1:] = - trans_vertices[..., 1:] + landmarks2d = util.batch_orth_proj(landmarks2d, cam) + landmarks2d[..., 1:] = - landmarks2d[..., 1:] + landmarks3d = util.batch_orth_proj(landmarks3d, cam) + landmarks3d[..., 1:] = - landmarks3d[..., 1:] + losses['landmark'] = util.l2_distance(landmarks2d[:, :, :2], gt_landmark[:, :, :2]) + + # render + albedos = self.flametex(tex) / 255. + ops = self.render(vertices, trans_vertices, albedos, lights) + predicted_images = ops['images'] + # losses['photometric_texture'] = (image_masks * (ops['images'] - images).abs()).mean() * config.w_pho + losses['photometric_texture'] = F.smooth_l1_loss(image_masks * ops['images'], + image_masks * images) * cfg.w_pho + + all_loss = 0. 
+ for key in losses.keys(): + all_loss = all_loss + losses[key] + losses['all_loss'] = all_loss + e_opt.zero_grad() + all_loss.backward() + e_opt.step() + + loss_info = '----iter: {}, time: {}\n'.format(k, datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')) + for key in losses.keys(): + loss_info = loss_info + '{}: {}, '.format(key, float(losses[key])) + + if k % 10 == 0: + all_train_iter += 10 + all_train_iters.append(all_train_iter) + photometric_loss.append(losses['photometric_texture']) + print(loss_info) + + grids = {} + visind = range(bz) # [0] + grids['images'] = torchvision.utils.make_grid(images[visind]).detach().cpu() + grids['landmarks_gt'] = torchvision.utils.make_grid( + util.tensor_vis_landmarks(images[visind], landmarks[visind])) + grids['landmarks2d'] = torchvision.utils.make_grid( + util.tensor_vis_landmarks(images[visind], landmarks2d[visind])) + grids['landmarks3d'] = torchvision.utils.make_grid( + util.tensor_vis_landmarks(images[visind], landmarks3d[visind])) + grids['albedoimage'] = torchvision.utils.make_grid( + (ops['albedo_images'])[visind].detach().cpu()) + grids['render'] = torchvision.utils.make_grid(predicted_images[visind].detach().float().cpu()) + shape_images = self.render.render_shape(vertices, trans_vertices, images) + grids['shape'] = torchvision.utils.make_grid( + F.interpolate(shape_images[visind], [224, 224])).detach().float().cpu() + + # grids['tex'] = torchvision.utils.make_grid(F.interpolate(albedos[visind], [224, 224])).detach().cpu() + grid = torch.cat(list(grids.values()), 1) + grid_image = (grid.numpy().transpose(1, 2, 0).copy() * 255)[:, :, [2, 1, 0]] + grid_image = np.minimum(np.maximum(grid_image, 0), 255).astype(np.uint8) + video_writer.write(grid_image) + + single_params = { + 'shape': shape.detach().cpu().numpy(), + 'exp': exp.detach().cpu().numpy(), + 'pose': pose.detach().cpu().numpy(), + 'cam': cam.detach().cpu().numpy(), + 'verts': trans_vertices.detach().cpu().numpy(), + 'albedos': 
albedos.detach().cpu().numpy(), + 'tex': tex.detach().cpu().numpy(), + 'lit': lights.detach().cpu().numpy() + } + util.draw_train_process("training", all_train_iters, photometric_loss, 'photometric loss') + # np.save("./test_results/model.npy", single_params) + return single_params + + def run(self, img, net, rect_detect, landmark_detect, rect_thresh, save_name, video_writer, savefolder): + # The implementation is potentially able to optimize with images(batch_size>1), + # here we show the example with a single image fitting + images = [] + landmarks = [] + image_masks = [] + bbox = rect_detect.extract(img, rect_thresh) + if len(bbox) > 0: + crop_image, new_bbox = util.crop_img(img, bbox[0], cfg.cropped_size) + + # input landmark + resize_img, landmark = landmark_detect.extract([crop_image, [new_bbox]]) + landmark = landmark[0] + landmark[:, 0] = landmark[:, 0] / float(resize_img.shape[1]) * 2 - 1 + landmark[:, 1] = landmark[:, 1] / float(resize_img.shape[0]) * 2 - 1 + landmarks.append(torch.from_numpy(landmark)[None, :, :].double().to(self.device)) + landmarks = torch.cat(landmarks, dim=0) + + # input image + image = cv2.resize(crop_image, (cfg.cropped_size, cfg.cropped_size)).astype(np.float32) / 255. 
if __name__ == '__main__':
    # Usage: python demos/wj_fitting.py <image> <device>
    image_path = str(sys.argv[1])
    device_name = str(sys.argv[2])

    base_name = os.path.split(image_path)[1].split(".")[0]
    save_name = base_name + '.obj'
    save_video_name = base_name + '.avi'

    # cv2.imread returns None instead of raising when the file cannot be read;
    # fail here with a clear message rather than deep inside the detector.
    img = cv2.imread(image_path)
    if img is None:
        raise IOError('cannot read image: {}'.format(image_path))

    # The output folder must exist before the video writer is opened on it.
    util.check_mkdir(cfg.save_folder)
    fitting = PhotometricFitting(device=device_name)

    face_detect = detector.SFDDetector(device_name, cfg.rect_model_path)
    face_landmark = FAN_landmark.FANLandmarks(device_name, cfg.landmark_model_path, cfg.face_detect_type)

    # NOTE(review): the segmentation net is always placed on the default CUDA
    # device, regardless of device_name — confirm that util.face_seg expects
    # this before changing it.
    seg_net = BiSeNet(n_classes=cfg.seg_class).cuda()
    seg_net.load_state_dict(torch.load(cfg.face_seg_model))
    seg_net.eval()

    video_writer = cv2.VideoWriter(os.path.join(cfg.save_folder, save_video_name),
                                   cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), 16,
                                   (cfg.image_size, cfg.image_size * 7))
    try:
        fitting.run(img, seg_net, face_detect, face_landmark, cfg.rect_thresh, save_name, video_writer,
                    cfg.save_folder)
    finally:
        # Flush and close the AVI even when fitting raises.
        video_writer.release()
class FANLandmarks:
    """Predict 68 facial landmarks for given face rectangles with a FAN network.

    When ``detect_type`` is ``"3D"`` a second network (``ResNetDepth``) is
    loaded and a third coordinate is appended to every landmark.
    """

    def __init__(self, device, model_path, detect_type,
                 depth_model_path="D:/model/depth-2a464da4ea.pth.tar"):
        """Load the landmark network and, for ``"3D"``, the depth network.

        Args:
            device: torch device (or device string) the networks run on.
            model_path: checkpoint holding the FAN ``state_dict``.
            detect_type: ``"3D"`` enables depth prediction; any other value
                yields 2-D landmarks only.
            depth_model_path: checkpoint of the depth network, read only when
                ``detect_type == "3D"``. The default is the path that used to
                be hard-coded here, so existing callers behave as before.
        """
        # map_location lets a checkpoint saved on another device be loaded
        # directly onto the requested one.
        model_weights = torch.load(model_path, map_location=device)
        self.device = device
        self.detect_type = detect_type
        torch.backends.cudnn.benchmark = True
        self.face_landmark = FAN(4)
        self.face_landmark.load_state_dict(model_weights)
        self.face_landmark.to(device)
        self.face_landmark.eval()
        self.reference_scale = 195.0

        if self.detect_type == "3D":
            self.depth_prediciton_net = ResNetDepth()
            depth_weights = torch.load(depth_model_path, map_location=device)
            # Drop the 'module.' substring from every key of the checkpoint.
            # NOTE(review): presumably a DataParallel prefix — confirm.
            depth_dict = {k.replace('module.', ''): v for k, v in depth_weights['state_dict'].items()}
            self.depth_prediciton_net.load_state_dict(depth_dict)
            self.depth_prediciton_net.to(device)
            self.depth_prediciton_net.eval()

    def extract(self, rect_queue):
        """Return the image and one landmark array per face rectangle.

        Args:
            rect_queue: pair ``(image, face_rect)`` where ``face_rect`` is an
                iterable of ``(x1, y1, x2, y2)`` boxes.

        Returns:
            ``(image, landmarks)``; ``landmarks`` is a list with one numpy
            array per rectangle, of shape ``(68, 2)``, or ``(68, 3)`` when
            ``detect_type == "3D"``.
        """
        image, face_rect = rect_queue
        landmarks = []
        for d in face_rect:
            center_x = d[2] - (d[2] - d[0]) / 2.0
            center_y = d[3] - (d[3] - d[1]) / 2.0
            center = torch.FloatTensor([center_x, center_y])
            # Box width plus height, relative to the reference scale.
            scale = (d[2] - d[0] + d[3] - d[1]) / self.reference_scale

            inp = crop(image, center, scale)
            inp = torch.from_numpy(inp.transpose((2, 0, 1))).float().to(self.device)
            inp.div_(255.0).unsqueeze_(0)
            with torch.no_grad():
                # The network returns a list of outputs; only the last is used.
                out = self.face_landmark(inp)[-1]
            out = out.cpu()
            pts, pts_img = get_preds_fromhm(out, center, scale)
            pts, pts_img = pts.view(68, 2) * 4, pts_img.view(68, 2)

            if self.detect_type == "3D":
                # One Gaussian heatmap per landmark, stacked with the crop as
                # input of the depth network.
                heatmaps = np.zeros((68, 256, 256), dtype=np.float32)
                for k in range(68):
                    if pts[k, 0] > 0:
                        heatmaps[k] = draw_gaussian(heatmaps[k], pts[k], 2)
                heatmaps = torch.from_numpy(heatmaps).unsqueeze_(0)

                heatmaps = heatmaps.to(self.device)
                depth_pred = self.depth_prediciton_net(torch.cat((inp, heatmaps), 1)).data.cpu().view(68, 1)
                pts_img = torch.cat((pts_img, depth_pred * (1.0 / (256.0 / (200.0 * scale)))), 1)

            landmarks.append(pts_img.numpy())

        return image, landmarks
def encode(matched, priors, variances):
    """Encode matched ground-truth boxes as regression targets relative to priors.

    Args:
        matched: (tensor) ground-truth boxes in point form (x1, y1, x2, y2),
            shape [num_priors, 4].
        priors: (tensor) prior boxes in center-offset form (cx, cy, w, h),
            shape [num_priors, 4].
        variances: (list[float]) two scaling factors; the first divides the
            center offsets, the second divides the log size ratios.
    Return:
        encoded boxes (tensor), shape [num_priors, 4]
    """
    top_left, bottom_right = matched[:, :2], matched[:, 2:]
    prior_centers, prior_sizes = priors[:, :2], priors[:, 2:]

    # offset between the matched box center and the prior center, variance-scaled
    center_offset = ((top_left + bottom_right) / 2 - prior_centers) / (variances[0] * prior_sizes)
    # log of (matched size / prior size), variance-scaled
    log_size_ratio = torch.log((bottom_right - top_left) / prior_sizes) / variances[1]

    return torch.cat([center_offset, log_size_ratio], 1)  # [num_priors, 4]
class Bottleneck(nn.Module):
    """Residual block: 1x1 conv, 3x3 conv (optionally strided), 1x1 conv to ``planes * 4``.

    Each convolution is followed by batch normalisation; the block input
    (passed through ``downsample`` when one is given) is added before the
    final ReLU.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        wide_planes = planes * 4

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, wide_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(wide_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Skip connection: identity unless a projection module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
class ResNetDepth(nn.Module):
    """Four-stage residual network that maps an image stacked with 68 heatmaps to a vector.

    The first convolution takes ``3 + 68`` input channels; the final linear
    layer emits ``num_classes`` values per sample.

    Args:
        block: residual block class. It must expose an ``expansion`` attribute
            and accept ``(inplanes, planes, stride, downsample)``.
        layers: number of blocks in each of the four stages. The default is an
            immutable tuple so it cannot be shared and mutated across instances.
        num_classes: size of the output vector.
    """

    def __init__(self, block=Bottleneck, layers=(3, 8, 36, 3), num_classes=68):
        super(ResNetDepth, self).__init__()
        # Running channel count consumed and updated by _make_layer.
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # The pooled map must be 1x1 for the flattened features to match fc.
        self.avgpool = nn.AvgPool2d(7)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Zero-mean normal init with std sqrt(2 / (kh * kw * out_channels)).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks and update ``self.inplanes``."""
        downsample = None
        # A projection is needed on the skip path whenever the first block
        # changes resolution or channel count.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
class L2Norm(nn.Module):
    """Normalise each spatial position to unit L2 norm over channels, then rescale per channel.

    Args:
        n_channels: number of channels of the input feature map.
        scale: initial value of every entry of the learnable per-channel weight.
    """

    def __init__(self, n_channels, scale=1.0):
        super(L2Norm, self).__init__()
        self.n_channels = n_channels
        self.scale = scale
        self.eps = 1e-10
        # Build the weight from a defined value. Starting from uninitialised
        # memory and multiplying by zero would keep any NaN/Inf it contained.
        self.weight = nn.Parameter(torch.full((self.n_channels,), float(self.scale)))

    def forward(self, x):
        # eps keeps the division finite where the channel vector is all zeros.
        norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
        x = x / norm * self.weight.view(1, -1, 1, 1)
        return x
scale=8) + self.conv5_3_norm = L2Norm(512, scale=5) + + self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) + self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) + self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) + self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) + self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) + self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) + + self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1) + self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1) + self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) + self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) + self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1) + self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + h = F.relu(self.conv1_1(x)) + h = F.relu(self.conv1_2(h)) + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv2_1(h)) + h = F.relu(self.conv2_2(h)) + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv3_1(h)) + h = F.relu(self.conv3_2(h)) + h = F.relu(self.conv3_3(h)) + f3_3 = h + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv4_1(h)) + h = F.relu(self.conv4_2(h)) + h = F.relu(self.conv4_3(h)) + f4_3 = h + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv5_1(h)) + h = F.relu(self.conv5_2(h)) + h = F.relu(self.conv5_3(h)) + f5_3 = h + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.fc6(h)) + h = F.relu(self.fc7(h)) + ffc7 = h + h = F.relu(self.conv6_1(h)) + h = F.relu(self.conv6_2(h)) + f6_2 = h + h = F.relu(self.conv7_1(h)) + h = F.relu(self.conv7_2(h)) + f7_2 = h + + f3_3 = self.conv3_3_norm(f3_3) + f4_3 = self.conv4_3_norm(f4_3) + f5_3 = self.conv5_3_norm(f5_3) + + cls1 = self.conv3_3_norm_mbox_conf(f3_3) + 
class SFDDetector:
    """Face detector wrapping the ``s3fd`` network."""

    def __init__(self, device, model_path):
        """Load the detector weights from ``model_path`` onto ``device``."""
        # map_location lets a checkpoint saved on another device be loaded
        # directly onto the requested one.
        model_weights = torch.load(model_path, map_location=device)
        torch.backends.cudnn.benchmark = True
        self.device = device
        self.face_detector = s3fd().to(self.device)
        self.face_detector.load_state_dict(model_weights)
        self.face_detector.eval()

    def pre_process_frame(self, frame):
        """Reverse the channel axis, subtract per-channel offsets, return a 1xCxHxW array."""
        img = frame[..., ::-1]
        img = img - np.array([104, 117, 123])
        img = img.transpose(2, 0, 1)
        img = img.reshape((1,) + img.shape)
        return img

    def detect_rect(self, frame, thresh):
        """Run the network on ``frame`` and decode every candidate scoring above ``thresh``.

        Returns:
            numpy array with one ``[x1, y1, x2, y2, score]`` row per candidate
            (empty when nothing passes the threshold).
        """
        img = self.pre_process_frame(frame)
        img = torch.from_numpy(img).float().to(self.device)
        with torch.no_grad():
            olist = self.face_detector(img)

        bboxes = []
        olist = [oelem.data.cpu() for oelem in olist]
        # The network output alternates class scores and box regressions, one
        # pair per detection scale.
        for i in range(len(olist) // 2):
            ocls, oreg = olist[i * 2], olist[i * 2 + 1]
            stride = 2 ** (i + 2)  # 4,8,16,32,64,128
            # Cheap pre-filter: only positions scoring above 0.05 are examined.
            poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
            # The batch index is ignored; sample 0 is read below.
            for _, hindex, windex in poss:
                axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
                score = ocls[0, 1, hindex, windex]
                if score > thresh:
                    loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
                    priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
                    variances = [0.1, 0.2]
                    box = decode(loc, priors, variances)
                    x1, y1, x2, y2 = box[0] * 1.0
                    bboxes.append([x1, y1, x2, y2, score])
        bboxes = np.array(bboxes)
        return bboxes

    def extract(self, frame, thresh):
        """Detect faces in ``frame`` and return ``(left, top, right, bottom)`` tuples.

        Candidates are reduced with non-maximum suppression (overlap 0.3),
        boxes whose shorter side is below 40 are dropped, and the bottom edge
        of every kept box is extended by 10% of its height.
        """
        bboxes = self.detect_rect(frame, thresh)
        if len(bboxes) == 0:
            return []

        keep = nms(bboxes, 0.3)
        bboxlist = bboxes[keep, :]
        # restore the rect points
        detected_faces = []
        for left, top, right, bottom, _ in bboxlist:
            height = bottom - top
            if min(right - left, height) < 40:
                continue
            bottom += height * 0.1
            detected_faces.append((left, top, right, bottom))
        return detected_faces
def transform(point, center, scale, resolution, invert=False):
    """Map a 2D point through the crop's affine transformation.

    A 3x3 matrix is built from ``center``, ``scale`` and ``resolution``: it
    scales by ``resolution / (200 * scale)`` and translates so that ``center``
    lands in the middle of the output. With ``invert=True`` the inverse
    matrix is applied instead.

    Arguments:
        point {sequence or torch.tensor} -- the input 2D point
        center {sequence, torch.tensor or numpy.array} -- the center around which to transform
        scale {float} -- the scale of the face/object
        resolution {float} -- the output resolution

    Keyword Arguments:
        invert {bool} -- apply the inverse transformation (default: {False})

    Returns:
        torch.tensor -- the transformed (x, y), truncated to int32
    """
    homogeneous = torch.ones(3)
    homogeneous[0], homogeneous[1] = point[0], point[1]

    box_size = 200.0 * scale
    affine = torch.eye(3)
    for axis in (0, 1):
        affine[axis, axis] = resolution / box_size
        affine[axis, 2] = resolution * (-center[axis] / box_size + 0.5)

    if invert:
        affine = torch.inverse(affine)

    return torch.matmul(affine, homogeneous)[0:2].int()
def get_preds_fromhm(hm, center=None, scale=None):
    """Obtain (x, y) coordinates from a set of N heatmaps.

    For every heatmap the arg-max position is taken and, when the peak is not
    on the border, shifted by a quarter pixel towards the stronger neighbour
    along each axis. Coordinates are 1-based at that point and are then
    shifted by -0.5. If both ``center`` and ``scale`` are provided the points
    are also mapped back to the original image frame via ``transform``.

    Arguments:
        hm {torch.tensor} -- the predicted heatmaps, indexed as
            [batch, point, row, column]

    Keyword Arguments:
        center {torch.tensor} -- the center of the bounding box (default: {None})
        scale {float} -- face scale (default: {None})

    Returns:
        (preds, preds_orig) -- both of shape [B, N, 2]; ``preds_orig`` is all
        zeros unless ``center`` and ``scale`` are given.
    """
    batch, n_points, height, width = hm.size()
    _, flat_idx = torch.max(hm.reshape(batch, n_points, height * width), 2)

    # 1-based peak position. The flattened index is row * width + column, so
    # both the remainder and the quotient are taken with respect to the width;
    # this also holds for non-square heatmaps.
    preds = torch.stack((flat_idx % width + 1, flat_idx // width + 1), dim=2).float()

    for i in range(batch):
        for j in range(n_points):
            hm_ = hm[i, j]
            pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
            # Refine only interior peaks: all four neighbours must exist. The
            # bounds follow the actual heatmap size, not a fixed 64x64.
            if 0 < pX < width - 1 and 0 < pY < height - 1:
                diff = torch.stack((hm_[pY, pX + 1] - hm_[pY, pX - 1],
                                    hm_[pY + 1, pX] - hm_[pY - 1, pX]))
                preds[i, j].add_(diff.sign().mul(.25).to(preds))

    preds.add_(-.5)

    preds_orig = torch.zeros(preds.size())
    if center is not None and scale is not None:
        for i in range(batch):
            for j in range(n_points):
                preds_orig[i, j] = transform(
                    preds[i, j], center, scale, height, True)

    return preds, preds_orig
For example, a `expansion_factor` of 0.2 leads + to 20% increase in width and height of the boxes + :return: a batch of bounding boxes of dim (n x 4) where the second dim is (x1,y1,x2,y2) + """ + # Calc bounding box + x_y_min, _ = target_landmarks.reshape(-1, 68, 2).min(dim=1) + x_y_max, _ = target_landmarks.reshape(-1, 68, 2).max(dim=1) + # expanding the bounding box + expansion_factor /= 2 + bb_expansion_x = (x_y_max[:, 0] - x_y_min[:, 0]) * expansion_factor + bb_expansion_y = (x_y_max[:, 1] - x_y_min[:, 1]) * expansion_factor + x_y_min[:, 0] -= bb_expansion_x + x_y_max[:, 0] += bb_expansion_x + x_y_min[:, 1] -= bb_expansion_y + x_y_max[:, 1] += bb_expansion_y + return torch.cat([x_y_min, x_y_max], dim=1) + + +def shuffle_lr(parts, pairs=None): + """Shuffle the points left-right according to the axis of symmetry + of the object. + Arguments: + parts {torch.tensor} -- a 3D or 4D object containing the + heatmaps. + Keyword Arguments: + pairs {list of integers} -- [order of the flipped points] (default: {None}) + """ + if pairs is None: + pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, + 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35, + 34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41, + 40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63, + 62, 61, 60, 67, 66, 65] + if parts.ndimension() == 3: + parts = parts[pairs, ...] + else: + parts = parts[:, pairs, ...] 
+ + return parts + + +def flip(tensor, is_label=False): + """Flip an image or a set of heatmaps left-right + Arguments: + tensor {numpy.array or torch.tensor} -- [the input image or heatmaps] + Keyword Arguments: + is_label {bool} -- [denote wherever the input is an image or a set of heatmaps ] (default: {False}) + """ + if not torch.is_tensor(tensor): + tensor = torch.from_numpy(tensor) + + if is_label: + tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1) + else: + tensor = tensor.flip(tensor.ndimension() - 1) + + return tensor + + +# From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py) + + +def appdata_dir(appname=None, roaming=False): + """ appdata_dir(appname=None, roaming=False) + Get the path to the application directory, where applications are allowed + to write user specific files (e.g. configurations). For non-user specific + data, consider using common_appdata_dir(). + If appname is given, a subdir is appended (and created if necessary). + If roaming is True, will prefer a roaming directory (Windows Vista/7). 
+ """ + + # Define default user directory + userDir = os.getenv('FACEALIGNMENT_USERDIR', None) + if userDir is None: + userDir = os.path.expanduser('~') + if not os.path.isdir(userDir): # pragma: no cover + userDir = '/var/tmp' # issue #54 + + # Get system app data dir + path = None + if sys.platform.startswith('win'): + path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA') + path = (path2 or path1) if roaming else (path1 or path2) + elif sys.platform.startswith('darwin'): + path = os.path.join(userDir, 'Library', 'Application Support') + # On Linux and as fallback + if not (path and os.path.isdir(path)): + path = userDir + + # Maybe we should store things local to the executable (in case of a + # portable distro or a frozen application that wants to be portable) + prefix = sys.prefix + if getattr(sys, 'frozen', None): + prefix = os.path.abspath(os.path.dirname(sys.executable)) + for reldir in ('settings', '../settings'): + localpath = os.path.abspath(os.path.join(prefix, reldir)) + if os.path.isdir(localpath): # pragma: no cover + try: + open(os.path.join(localpath, 'test.write'), 'wb').close() + os.remove(os.path.join(localpath, 'test.write')) + except IOError: + pass # We cannot write in this directory + else: + path = localpath + break + + # Get path specific for this app + if appname: + if path == userDir: + appname = '.' 
+ appname.lstrip('.') # Make it a hidden directory + path = os.path.join(path, appname) + if not os.path.isdir(path): # pragma: no cover + os.mkdir(path) + + # Done + return path \ No newline at end of file diff --git a/facial_alignment/face_detect.py b/facial_alignment/face_detect.py new file mode 100644 index 00000000..2b0c8f15 --- /dev/null +++ b/facial_alignment/face_detect.py @@ -0,0 +1,189 @@ +from facial_alignment.detection import sfd_detector as detector +from facial_alignment.detection import FAN_landmark +import cv2 +import torch +import socket +from TG_thread.FullProcess import FullSwapProcess +from collections import deque +import time +from head_pose_estimation.pose_estimator import PoseEstimator +from head_pose_estimation.stabilizer import Stabilizer +from head_pose_estimation.visualization import * +from head_pose_estimation.misc import * +import multiprocessing as mp +from DFLIMG import * +from pathlib import Path + + +def resize_para(ori_frame): + w, h, c = ori_frame.shape + d = max(w, h) + scale_to = 640 if d >= 1280 else d / 2 + scale_to = max(64, scale_to) + input_scale = d / scale_to + w = int(w / input_scale) + h = int(h / input_scale) + image_info = [w, h, input_scale] + return image_info + + +def cv2_imwrite(filename, img, *args): + ret, buf = cv2.imencode(Path(filename).suffix, img, *args) + if ret == True: + try: + with open(filename, "wb") as stream: + stream.write(buf) + except: + pass + + +video_path = "E:/data/face_video/20200810hsc.mp4" +rect_model_path = "D:/model/s3fd.pth" +landmark_model_path = "D:/model/2DFAN4-11f355bf06.pth.tar" +if __name__ == "__main__": + # mp.set_start_method("spawn") + # ctx = mp.get_context("spawn") + # FullSwapProcess(ctx, video_path, 'cuda', 0.5).start() + cap = cv2.VideoCapture(video_path) + ret, frame = cap.read() + # frame = cv2.imread("D:/data/face.png") + # frame = cv2.resize(frame, (640, 360)) + + w_h_scale = resize_para(frame) + device = torch.device("cuda") + face_detect = 
detector.SFDDetector(device, rect_model_path, w_h_scale) + face_landmark = FAN_landmark.FANLandmarks(device, landmark_model_path) + thresh = 0.5 + + # pose_estimator = PoseEstimator(img_size=frame.shape[:2]) + # pose_stabilizers = [Stabilizer( + # state_num=2, + # measure_num=1, + # cov_process=0.01, + # cov_measure=0.1) for _ in range(8)] + # address = ('127.0.0.1', 5066) + # s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # s.connect(address) + frame_count = 0 + ts = [] + no_face_count = 0 + prev_boxes = deque(maxlen=5) + prev_marks = deque(maxlen=5) + + while True: + # frame_count += 1 + # flip_frame = cv2.flip(frame, 2) + # t = time.time() + # facebox = face_detect.extract(flip_frame, thresh) + # + # box = facebox[0] + # + # if len(box) > 0: + # # or every even frame + # box = [int(i) for i in box] + # face_img = flip_frame[box[1]: box[3], box[0]: box[2]] + # transfer_frame, landmarks = face_landmark.extract([flip_frame, facebox]) + # marks = landmarks[-1] + # + # # x_l, y_l, ll, lu = detect_iris(frame, marks, "left") + # # x_r, y_r, rl, ru = detect_iris(frame, marks, "right") + # pose = pose_estimator.solve_pose_by_68_points(marks) + # + # # pose_estimator.draw_annotation_box( + # # frame, pose[0], pose[1], color=(128, 255, 128)) + # + # steady_pose = [] + # pose_np = np.array(pose).flatten() + # for value, ps_stb in zip(pose_np, pose_stabilizers): + # ps_stb.update([value]) + # steady_pose.append(ps_stb.state[0]) + # # + # # roll = np.clip(-(180 + np.degrees(steady_pose[2])), -50, 50)[0] + # # pitch = np.clip(-(np.degrees(steady_pose[1])) - 15, -40, 40)[0] + # # yaw = np.clip(-(np.degrees(steady_pose[0])), -50, 50)[0] + # roll = float(np.degrees(steady_pose[2])) + # pitch = float(np.degrees(steady_pose[1])) + # yaw = float(np.degrees(steady_pose[0])) + # + # mouse_open_ratio = np.linalg.norm(marks[62] - marks[66]) / np.linalg.norm(marks[60] - marks[64]) * 2 + # print(mouse_open_ratio) + # # if frame_count > 60: # send information to unity + # # msg = 
'%.4f %.4f %.4f %.4f' % \ + # # (roll, pitch, yaw, mouse_open_ratio) + # # s.send(bytes(msg, "utf-8")) + # cv2.imshow("Preview", frame) + # if cv2.waitKey(1) == 27: + # break + # ret, frame = cap.read() + # frame = cv2.resize(frame, (640, 360)) + + total_start = cv2.getTickCount() + rect_start = cv2.getTickCount() + bbox = face_detect.extract(frame, thresh) + rect_end = cv2.getTickCount() + rect_time = (rect_end - rect_start) / cv2.getTickFrequency() * 1000 + print("rect time: ", rect_time, "ms...") + if len(bbox) > 0: + frame, landmarks = face_landmark.extract([frame, bbox]) + total_end = cv2.getTickCount() + total_time = (total_end - total_start) / cv2.getTickFrequency() * 1000 + print("total time: ", total_time, "ms...") + for land in landmarks: + # max_x, max_y = land.astype("int32").max(axis=0) + # min_x, min_y = land.astype("int32").min(axis=0) + # center_x = (max_x + min_x) // 2 + # center_y = (max_y + min_y) // 2 + # w = int((max_x - min_x) * 1.6) + # h = int((max_y - min_y) * 2.2) + # x1 = center_x - h // 2 + # x2 = x1 + h + # y1 = center_y - h // 3 * 2 + # y2 = y1 + h + # cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 255)) + for (x, y) in land: + cv2.circle(frame, (x, y), 2, (0, 0, 255), -1) + img_h, img_w, _ = frame.shape + ratio = 1.6 + if img_w == 1280: + ratio = 1.4 + elif img_w == 3840: + ratio = 1.8 + if len(bbox) > 0: + for box in bbox: + box = [int(i) for i in box] + x, y, x2, y2 = box + center_x = (x + x2) // 2 + center_y = (y + y2) // 2 + h = int((y2 - y) * ratio) + if center_x < h / 2: + h = center_x + if center_y < h / 2: + h = center_y + if center_x + h / 2 > img_w: + h = (img_w - center_x) * 2 + if center_y + h / 2 > img_h: + h = (img_h - center_y) * 2 + + x1 = max(center_x - h // 2, 0) + x2 = min(x1 + h, img_w) + y1 = max(center_y - h // 3 * 2, 0) + y2 = min(y1 + h, img_h) + cut_frame = frame[y1:y2, x1:x2, :] + cut_frame = cv2.resize(cut_frame, (768, 768)) + cv2.imshow("cut_frame", cut_frame) + cv2.waitKey(1) + frame_count += 1 + 
output_filepath = "E:/data/hsc/" + str(frame_count).zfill(5) + ".jpg" + # cv2_imwrite(output_filepath, cut_frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90]) + # cv2.imwrite(output_filepath, cut_frame) + import os + if os.path.isfile(output_filepath): + dflimg = DFLJPG.load(output_filepath) + dflimg.set_source_rect((x1, y1, x2, y2)) + + dflimg.save() + cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 255)) + cv2.imshow("frame", frame) + cv2.waitKey(1) + ret, frame = cap.read() + # frame = cv2.resize(frame, (640, 360)) diff --git a/gif_helper.py b/gif_helper.py index 146cbf36..af1d6279 100644 --- a/gif_helper.py +++ b/gif_helper.py @@ -1,21 +1,11 @@ -import os, sys -import cv2 +import sys import torch -import torchvision -import torch.nn.functional as F -import torch.nn as nn -import numpy as np -from pytorch3d.io import load_obj - -from glob import glob -import time -import datetime -import imageio sys.path.append('./models/') from FLAME import FLAME, FLAMETex -from renderer import Renderer -import util +from utils.renderer import Renderer +from utils import util + torch.backends.cudnn.benchmark = True diff --git a/models/FLAME.py b/models/FLAME.py index 6b618f72..50ec7303 100644 --- a/models/FLAME.py +++ b/models/FLAME.py @@ -10,7 +10,8 @@ import pickle import torch.nn.functional as F -from lbs import lbs, batch_rodrigues, vertices2landmarks +from models.lbs import lbs, batch_rodrigues, vertices2landmarks + def to_tensor(array, dtype=torch.float32): if 'torch.tensor' not in str(type(array)): @@ -206,12 +207,12 @@ def forward(self, shape_params=None, expression_params=None, pose_params=None, e lmk_bary_coords = torch.cat([dyn_lmk_bary_coords, lmk_bary_coords], 1) landmarks2d = vertices2landmarks(vertices, self.faces_tensor, - lmk_faces_idx, - lmk_bary_coords) + lmk_faces_idx, + lmk_bary_coords) bz = vertices.shape[0] landmarks3d = vertices2landmarks(vertices, self.faces_tensor, - self.full_lmk_faces_idx.repeat(bz, 1), - self.full_lmk_bary_coords.repeat(bz, 1, 1)) + 
self.full_lmk_faces_idx.repeat(bz, 1), + self.full_lmk_bary_coords.repeat(bz, 1, 1)) return vertices, landmarks2d, landmarks3d diff --git a/models/__init__.py b/models/__init__.py index d5149e8b..8b137891 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,2 +1 @@ -from . import FLAME -from . import lbs + diff --git a/models/face_seg_model.py b/models/face_seg_model.py new file mode 100644 index 00000000..971a0157 --- /dev/null +++ b/models/face_seg_model.py @@ -0,0 +1,281 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from models.resnet import Resnet18 + + +class ConvBNReLU(nn.Module): + def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): + super(ConvBNReLU, self).__init__() + self.conv = nn.Conv2d(in_chan, + out_chan, + kernel_size = ks, + stride = stride, + padding = padding, + bias = False) + self.bn = nn.BatchNorm2d(out_chan) + self.init_weight() + + def forward(self, x): + x = self.conv(x) + x = F.relu(self.bn(x)) + return x + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + +class BiSeNetOutput(nn.Module): + def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs): + super(BiSeNetOutput, self).__init__() + self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) + self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False) + self.init_weight() + + def forward(self, x): + x = self.conv(x) + x = self.conv_out(x) + return x + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, 
nn.Conv2d): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +class AttentionRefinementModule(nn.Module): + def __init__(self, in_chan, out_chan, *args, **kwargs): + super(AttentionRefinementModule, self).__init__() + self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) + self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False) + self.bn_atten = nn.BatchNorm2d(out_chan) + self.sigmoid_atten = nn.Sigmoid() + self.init_weight() + + def forward(self, x): + feat = self.conv(x) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv_atten(atten) + atten = self.bn_atten(atten) + atten = self.sigmoid_atten(atten) + out = torch.mul(feat, atten) + return out + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + +class ContextPath(nn.Module): + def __init__(self, *args, **kwargs): + super(ContextPath, self).__init__() + self.resnet = Resnet18() + self.arm16 = AttentionRefinementModule(256, 128) + self.arm32 = AttentionRefinementModule(512, 128) + self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) + + self.init_weight() + + def forward(self, x): + H0, W0 = x.size()[2:] + feat8, feat16, feat32 = self.resnet(x) + H8, W8 = feat8.size()[2:] + H16, W16 = feat16.size()[2:] + H32, W32 = feat32.size()[2:] + + avg = F.avg_pool2d(feat32, feat32.size()[2:]) + avg = self.conv_avg(avg) + avg_up = F.interpolate(avg, (H32, W32), mode='nearest') + + feat32_arm = self.arm32(feat32) + feat32_sum = feat32_arm + avg_up + feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') + feat32_up = 
self.conv_head32(feat32_up) + + feat16_arm = self.arm16(feat16) + feat16_sum = feat16_arm + feat32_up + feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') + feat16_up = self.conv_head16(feat16_up) + + return feat8, feat16_up, feat32_up # x8, x8, x16 + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, (nn.Linear, nn.Conv2d)): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +### This is not used, since I replace this with the resnet feature with the same size +class SpatialPath(nn.Module): + def __init__(self, *args, **kwargs): + super(SpatialPath, self).__init__() + self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3) + self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) + self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) + self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0) + self.init_weight() + + def forward(self, x): + feat = self.conv1(x) + feat = self.conv2(feat) + feat = self.conv3(feat) + feat = self.conv_out(feat) + return feat + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, 
nowd_params + + +class FeatureFusionModule(nn.Module): + def __init__(self, in_chan, out_chan, *args, **kwargs): + super(FeatureFusionModule, self).__init__() + self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) + self.conv1 = nn.Conv2d(out_chan, + out_chan//4, + kernel_size = 1, + stride = 1, + padding = 0, + bias = False) + self.conv2 = nn.Conv2d(out_chan//4, + out_chan, + kernel_size = 1, + stride = 1, + padding = 0, + bias = False) + self.relu = nn.ReLU(inplace=True) + self.sigmoid = nn.Sigmoid() + self.init_weight() + + def forward(self, fsp, fcp): + fcat = torch.cat([fsp, fcp], dim=1) + feat = self.convblk(fcat) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv1(atten) + atten = self.relu(atten) + atten = self.conv2(atten) + atten = self.sigmoid(atten) + feat_atten = torch.mul(feat, atten) + feat_out = feat_atten + feat + return feat_out + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +class BiSeNet(nn.Module): + def __init__(self, n_classes, *args, **kwargs): + super(BiSeNet, self).__init__() + self.cp = ContextPath() + ## here self.sp is deleted + self.ffm = FeatureFusionModule(256, 256) + self.conv_out = BiSeNetOutput(256, 256, n_classes) + self.conv_out16 = BiSeNetOutput(128, 64, n_classes) + self.conv_out32 = BiSeNetOutput(128, 64, n_classes) + self.init_weight() + + def forward(self, x): + H, W = x.size()[2:] + feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature + feat_sp = 
feat_res8 # use res3b1 feature to replace spatial path feature + feat_fuse = self.ffm(feat_sp, feat_cp8) + + feat_out = self.conv_out(feat_fuse) + feat_out16 = self.conv_out16(feat_cp8) + feat_out32 = self.conv_out32(feat_cp16) + + feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True) + feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True) + feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True) + return feat_out, feat_out16, feat_out32 + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], [] + for name, child in self.named_children(): + child_wd_params, child_nowd_params = child.get_params() + if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput): + lr_mul_wd_params += child_wd_params + lr_mul_nowd_params += child_nowd_params + else: + wd_params += child_wd_params + nowd_params += child_nowd_params + return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params + + +if __name__ == "__main__": + net = BiSeNet(19) + net.cuda() + net.eval() + in_ten = torch.randn(16, 3, 640, 480).cuda() + out, out16, out32 = net(in_ten) + print(out.shape) + + net.get_params() \ No newline at end of file diff --git a/models/resnet.py b/models/resnet.py new file mode 100644 index 00000000..295d1407 --- /dev/null +++ b/models/resnet.py @@ -0,0 +1,109 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.model_zoo as modelzoo + +# from modules.bn import InPlaceABNSync as BatchNorm2d + +resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return 
nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + def __init__(self, in_chan, out_chan, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(in_chan, out_chan, stride) + self.bn1 = nn.BatchNorm2d(out_chan) + self.conv2 = conv3x3(out_chan, out_chan) + self.bn2 = nn.BatchNorm2d(out_chan) + self.relu = nn.ReLU(inplace=True) + self.downsample = None + if in_chan != out_chan or stride != 1: + self.downsample = nn.Sequential( + nn.Conv2d(in_chan, out_chan, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(out_chan), + ) + + def forward(self, x): + residual = self.conv1(x) + residual = F.relu(self.bn1(residual)) + residual = self.conv2(residual) + residual = self.bn2(residual) + + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(x) + + out = shortcut + residual + out = self.relu(out) + return out + + +def create_layer_basic(in_chan, out_chan, bnum, stride=1): + layers = [BasicBlock(in_chan, out_chan, stride=stride)] + for i in range(bnum-1): + layers.append(BasicBlock(out_chan, out_chan, stride=1)) + return nn.Sequential(*layers) + + +class Resnet18(nn.Module): + def __init__(self): + super(Resnet18, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) + self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) + self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) + self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) + self.init_weight() + + def forward(self, x): + x = self.conv1(x) + x = F.relu(self.bn1(x)) + x = self.maxpool(x) + + x = self.layer1(x) + feat8 = self.layer2(x) # 1/8 + feat16 = self.layer3(feat8) # 1/16 + feat32 = self.layer4(feat16) # 1/32 + return feat8, feat16, feat32 + + def init_weight(self): + 
state_dict = modelzoo.load_url(resnet18_url) + self_state_dict = self.state_dict() + for k, v in state_dict.items(): + if 'fc' in k: continue + self_state_dict.update({k: v}) + self.load_state_dict(self_state_dict) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, (nn.Linear, nn.Conv2d)): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +if __name__ == "__main__": + net = Resnet18() + x = torch.randn(16, 3, 224, 224) + out = net(x) + print(out[0].size()) + print(out[1].size()) + print(out[2].size()) + net.get_params() \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/utils/config.py b/utils/config.py new file mode 100644 index 00000000..d5f27101 --- /dev/null +++ b/utils/config.py @@ -0,0 +1,41 @@ +from yacs.config import CfgNode +import os + +cfg = CfgNode() + +abs_root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +cfg.root_dir = abs_root_dir +cfg.device = 'cuda' +cfg.device_id = '0' +cfg.face_detect_type = "2D" +cfg.flame_model_path = os.path.join(cfg.root_dir, 'model', 'generic_model.pkl') +cfg.tex_space_path = os.path.join(cfg.root_dir, 'model', 'FLAME_texture.npz') +cfg.flame_model_path = os.path.join(cfg.root_dir, 'model', 'generic_model.pkl') +cfg.rect_model_path = os.path.join(cfg.root_dir, 'model', 's3fd.pth') +cfg.face_seg_model = os.path.join(cfg.root_dir, 'model', 'face_seg.pth') +cfg.landmark_model_path = os.path.join(cfg.root_dir, 'model', '2DFAN4-11f355bf06.pth.tar') +cfg.flame_lmk_embedding_path = os.path.join(cfg.root_dir, 'data', 'landmark_embedding.npy') +cfg.mesh_file = os.path.join(cfg.root_dir, 'data', 'head_template_mesh.obj') +cfg.save_folder = os.path.join(cfg.root_dir, 'test_results') 
+ +cfg.camera_params = 3 +cfg.shape_params = 100 +cfg.expression_params = 50 +cfg.pose_params = 6 +cfg.tex_params = 50 +cfg.seg_class = 19 +cfg.use_face_contour = True +cfg.cropped_size = 256 +cfg.batch_size = 1 +cfg.image_size = 224 +cfg.rect_thresh = 0.5 +cfg.e_lr = 0.005 +cfg.e_wd = 0.0001 +cfg.w_pho = 8 +cfg.w_lmks = 1 +cfg.max_iter = 2000 +cfg.w_shape_reg = 1e-4 +cfg.w_expr_reg = 1e-4 +cfg.w_pose_reg = 0 + + diff --git a/renderer.py b/utils/renderer.py similarity index 97% rename from renderer.py rename to utils/renderer.py index 81864155..c04d4908 100644 --- a/renderer.py +++ b/utils/renderer.py @@ -1,325 +1,324 @@ -""" -Author: Yao Feng -Copyright (c) 2020, Yao Feng -All rights reserved. -""" -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from skimage.io import imread -from pytorch3d.structures import Meshes -from pytorch3d.io import load_obj -from pytorch3d.renderer.mesh import rasterize_meshes -import util - - -class Pytorch3dRasterizer(nn.Module): - """ - This class implements methods for rasterizing a batch of heterogenous - Meshes. - - Notice: - x,y,z are in image space - """ - - def __init__(self, image_size=224): - """ - Args: - raster_settings: the parameters for rasterization. This should be a - named tuple. - All these initial settings can be overridden by passing keyword - arguments to the forward function. - """ - super().__init__() - raster_settings = { - 'image_size': image_size, - 'blur_radius': 0.0, - 'faces_per_pixel': 1, - 'bin_size': None, - 'max_faces_per_bin': None, - 'perspective_correct': False, - } - raster_settings = util.dict2obj(raster_settings) - self.raster_settings = raster_settings - - def forward(self, vertices, faces, attributes=None): - """ - Args: - meshes_world: a Meshes object representing a batch of meshes with - coordinates in world space. - Returns: - Fragments: Rasterization outputs as a named tuple. 
- """ - fixed_vetices = vertices.clone() - fixed_vetices[..., :2] = -fixed_vetices[..., :2] - meshes_screen = Meshes(verts=fixed_vetices.float(), faces=faces.long()) - raster_settings = self.raster_settings - - pix_to_face, zbuf, bary_coords, dists = rasterize_meshes( - meshes_screen, - image_size=raster_settings.image_size, - blur_radius=raster_settings.blur_radius, - faces_per_pixel=raster_settings.faces_per_pixel, - bin_size=raster_settings.bin_size, - max_faces_per_bin=raster_settings.max_faces_per_bin, - perspective_correct=raster_settings.perspective_correct, - ) - - vismask = (pix_to_face > -1).float() - D = attributes.shape[-1] - attributes = attributes.clone() - attributes = attributes.view(attributes.shape[0] * attributes.shape[1], 3, attributes.shape[-1]) - N, H, W, K, _ = bary_coords.shape - mask = pix_to_face == -1 # [] - pix_to_face = pix_to_face.clone() - pix_to_face[mask] = 0 - idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D) - pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D) - pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2) - pixel_vals[mask] = 0 # Replace masked values in output. - pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2) - pixel_vals = torch.cat([pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1) - # import ipdb; ipdb.set_trace() - return pixel_vals - - -class Renderer(nn.Module): - def __init__(self, image_size, obj_filename, uv_size=256): - super(Renderer, self).__init__() - self.image_size = image_size - self.uv_size = uv_size - - verts, faces, aux = load_obj(obj_filename) - uvcoords = aux.verts_uvs[None, ...] # (N, V, 2) - uvfaces = faces.textures_idx[None, ...] # (N, F, 3) - faces = faces.verts_idx[None, ...] 
- self.rasterizer = Pytorch3dRasterizer(image_size) - self.uv_rasterizer = Pytorch3dRasterizer(uv_size) - - # faces - self.register_buffer('faces', faces) - self.register_buffer('raw_uvcoords', uvcoords) - - # uv coordsw - uvcoords = torch.cat([uvcoords, uvcoords[:, :, 0:1] * 0. + 1.], -1) # [bz, ntv, 3] - uvcoords = uvcoords * 2 - 1 - uvcoords[..., 1] = -uvcoords[..., 1] - face_uvcoords = util.face_vertices(uvcoords, uvfaces) - self.register_buffer('uvcoords', uvcoords) - self.register_buffer('uvfaces', uvfaces) - self.register_buffer('face_uvcoords', face_uvcoords) - - # shape colors - colors = torch.tensor([74, 120, 168])[None, None, :].repeat(1, faces.max() + 1, 1).float() / 255. - face_colors = util.face_vertices(colors, faces) - self.register_buffer('face_colors', face_colors) - - ## lighting - pi = np.pi - constant_factor = torch.tensor( - [1 / np.sqrt(4 * pi), ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))), ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))), \ - ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))), (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))), - (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))), \ - (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))), (pi / 4) * (3 / 2) * (np.sqrt(5 / (12 * pi))), - (pi / 4) * (1 / 2) * (np.sqrt(5 / (4 * pi)))]) - self.register_buffer('constant_factor', constant_factor) - - - - def forward(self, vertices, transformed_vertices, albedos, lights=None, light_type='point'): - ''' - lihgts: - spherical homarnic: [N, 9(shcoeff), 3(rgb)] - vertices: [N, V, 3], vertices in work space, for calculating normals, then shading - transformed_vertices: [N, V, 3], range(-1, 1), projected vertices, for rendering - ''' - batch_size = vertices.shape[0] - ## rasterizer near 0 far 100. 
move mesh so minz larger than 0 - transformed_vertices[:, :, 2] = transformed_vertices[:, :, 2] + 10 - - # Attributes - face_vertices = util.face_vertices(vertices, self.faces.expand(batch_size, -1, -1)) - normals = util.vertex_normals(vertices, self.faces.expand(batch_size, -1, -1)) - face_normals = util.face_vertices(normals, self.faces.expand(batch_size, -1, -1)) - transformed_normals = util.vertex_normals(transformed_vertices, self.faces.expand(batch_size, -1, -1)) - transformed_face_normals = util.face_vertices(transformed_normals, self.faces.expand(batch_size, -1, -1)) - - # render - attributes = torch.cat([self.face_uvcoords.expand(batch_size, -1, -1, -1), transformed_face_normals.detach(), - face_vertices.detach(), face_normals.detach()], -1) - # import ipdb;ipdb.set_trace() - rendering = self.rasterizer(transformed_vertices, self.faces.expand(batch_size, -1, -1), attributes) - - alpha_images = rendering[:, -1, :, :][:, None, :, :].detach() - - # albedo - uvcoords_images = rendering[:, :3, :, :] - grid = (uvcoords_images).permute(0, 2, 3, 1)[:, :, :, :2] - - albedo_images = F.grid_sample(albedos, grid, align_corners=False) - - # remove inner mouth region - transformed_normal_map = rendering[:, 3:6, :, :].detach() - pos_mask = (transformed_normal_map[:, 2:, :, :] < -0.05).float() - - # shading - if lights is not None: - normal_images = rendering[:, 9:12, :, :].detach() - if lights.shape[1] == 9: - shading_images = self.add_SHlight(normal_images, lights) - else: - if light_type == 'point': - vertice_images = rendering[:, 6:9, :, :].detach() - shading = self.add_pointlight(vertice_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3]), - normal_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3]), - lights) - shading_images = shading.reshape( - [batch_size, lights.shape[1], albedo_images.shape[2], albedo_images.shape[3], 3]).permute(0, 1, - 4, 2, - 3) - shading_images = shading_images.mean(1) - else: - shading = 
self.add_directionlight(normal_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3]), - lights) - shading_images = shading.reshape( - [batch_size, lights.shape[1], albedo_images.shape[2], albedo_images.shape[3], 3]).permute(0, 1, - 4, 2, - 3) - shading_images = shading_images.mean(1) - images = albedo_images * shading_images - else: - images = albedo_images - shading_images = images.detach() * 0. - - outputs = { - 'images': images * alpha_images, - 'albedo_images': albedo_images, - 'alpha_images': alpha_images, - 'pos_mask': pos_mask, - 'shading_images': shading_images, - 'grid': grid, - 'normals': normals - } - - return outputs - - def add_SHlight(self, normal_images, sh_coeff): - ''' - sh_coeff: [bz, 9, 3] - ''' - N = normal_images - sh = torch.stack([ - N[:, 0] * 0. + 1., N[:, 0], N[:, 1], \ - N[:, 2], N[:, 0] * N[:, 1], N[:, 0] * N[:, 2], - N[:, 1] * N[:, 2], N[:, 0] ** 2 - N[:, 1] ** 2, 3 * (N[:, 2] ** 2) - 1 - ], - 1) # [bz, 9, h, w] - sh = sh * self.constant_factor[None, :, None, None] - # import ipdb; ipdb.set_trace() - shading = torch.sum(sh_coeff[:, :, :, None, None] * sh[:, :, None, :, :], 1) # [bz, 9, 3, h, w] - return shading - - def add_pointlight(self, vertices, normals, lights): - ''' - vertices: [bz, nv, 3] - lights: [bz, nlight, 6] - returns: - shading: [bz, nv, 3] - ''' - light_positions = lights[:,:,:3]; light_intensities = lights[:,:,3:] - directions_to_lights = F.normalize(light_positions[:,:,None,:] - vertices[:,None,:,:], dim=3) - # normals_dot_lights = torch.clamp((normals[:,None,:,:]*directions_to_lights).sum(dim=3), 0., 1.) 
- normals_dot_lights = (normals[:,None,:,:]*directions_to_lights).sum(dim=3) - shading = normals_dot_lights[:,:,:,None]*light_intensities[:,:,None,:] - return shading.mean(1) - - def add_directionlight(self, normals, lights): - ''' - normals: [bz, nv, 3] - lights: [bz, nlight, 6] - returns: - shading: [bz, nlgiht, nv, 3] - ''' - light_direction = lights[:, :, :3]; - light_intensities = lights[:, :, 3:] - directions_to_lights = F.normalize(light_direction[:, :, None, :].expand(-1, -1, normals.shape[1], -1), dim=3) - normals_dot_lights = (normals[:,None,:,:]*directions_to_lights).sum(dim=3) - shading = normals_dot_lights[:, :, :, None] * light_intensities[:, :, None, :] - return shading - - def render_shape(self, vertices, transformed_vertices, images=None, lights=None): - batch_size = vertices.shape[0] - if lights is None: - light_positions = torch.tensor([[-0.1, -0.1, 0.2], - [0, 0, 1]] - )[None, :, :].expand(batch_size, -1, -1).float() - light_intensities = torch.ones_like(light_positions).float() - lights = torch.cat((light_positions, light_intensities), 2).to(vertices.device) - - ## rasterizer near 0 far 100. 
move mesh so minz larger than 0 - transformed_vertices[:, :, 2] = transformed_vertices[:, :, 2] + 10 - - # Attributes - face_vertices = util.face_vertices(vertices, self.faces.expand(batch_size, -1, -1)) - normals = util.vertex_normals(vertices, self.faces.expand(batch_size, -1, -1)); - face_normals = util.face_vertices(normals, self.faces.expand(batch_size, -1, -1)) - transformed_normals = util.vertex_normals(transformed_vertices, self.faces.expand(batch_size, -1, -1)); - transformed_face_normals = util.face_vertices(transformed_normals, self.faces.expand(batch_size, -1, -1)) - # render - attributes = torch.cat( - [self.face_colors.expand(batch_size, -1, -1, -1), transformed_face_normals.detach(), face_vertices.detach(), - face_normals.detach()], -1) - rendering = self.rasterizer(transformed_vertices, self.faces.expand(batch_size, -1, -1), attributes) - # albedo - albedo_images = rendering[:, :3, :, :] - # shading - normal_images = rendering[:, 9:12, :, :].detach() - if lights.shape[1] == 9: - shading_images = self.add_SHlight(normal_images, lights) - else: - print('directional') - shading = self.add_directionlight(normal_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3]), lights) - - shading_images = shading.reshape( - [batch_size, lights.shape[1], albedo_images.shape[2], albedo_images.shape[3], 3]).permute(0, 1, 4, 2, 3) - shading_images = shading_images.mean(1) - images = albedo_images * shading_images - - return images - - def render_normal(self, transformed_vertices, normals): - ''' - -- rendering normal - ''' - batch_size = normals.shape[0] - - # Attributes - attributes = util.face_vertices(normals, self.faces.expand(batch_size, -1, -1)) - # rasterize - rendering = self.rasterizer(transformed_vertices, self.faces.expand(batch_size, -1, -1), attributes) - - #### - alpha_images = rendering[:, -1, :, :][:, None, :, :].detach() - normal_images = rendering[:, :3, :, :] - return normal_images - - def world2uv(self, vertices): - ''' - sample vertices from 
world space to uv space - uv_vertices: [bz, 3, h, w] - ''' - batch_size = vertices.shape[0] - face_vertices = util.face_vertices(vertices, self.faces.expand(batch_size, -1, -1)).clone().detach() - uv_vertices = self.uv_rasterizer(self.uvcoords.expand(batch_size, -1, -1), - self.uvfaces.expand(batch_size, -1, -1), face_vertices)[:, :3] - - return uv_vertices - - def save_obj(self, filename, vertices, textures): - ''' - vertices: [nv, 3], tensor - texture: [3, h, w], tensor - ''' - util.save_obj(filename, vertices, self.faces[0], textures=textures, uvcoords=self.raw_uvcoords[0], - uvfaces=self.uvfaces[0]) \ No newline at end of file +""" +Author: Yao Feng +Copyright (c) 2020, Yao Feng +All rights reserved. +""" +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from pytorch3d.structures import Meshes +from pytorch3d.io import load_obj +from pytorch3d.renderer.mesh import rasterize_meshes +from utils import util + + +class Pytorch3dRasterizer(nn.Module): + """ + This class implements methods for rasterizing a batch of heterogenous + Meshes. + + Notice: + x,y,z are in image space + """ + + def __init__(self, image_size=224): + """ + Args: + raster_settings: the parameters for rasterization. This should be a + named tuple. + All these initial settings can be overridden by passing keyword + arguments to the forward function. + """ + super().__init__() + raster_settings = { + 'image_size': image_size, + 'blur_radius': 0.0, + 'faces_per_pixel': 1, + 'bin_size': None, + 'max_faces_per_bin': None, + 'perspective_correct': False, + } + raster_settings = util.dict2obj(raster_settings) + self.raster_settings = raster_settings + + def forward(self, vertices, faces, attributes=None): + """ + Args: + meshes_world: a Meshes object representing a batch of meshes with + coordinates in world space. + Returns: + Fragments: Rasterization outputs as a named tuple. 
+ """ + fixed_vetices = vertices.clone() + fixed_vetices[..., :2] = -fixed_vetices[..., :2] + meshes_screen = Meshes(verts=fixed_vetices.float(), faces=faces.long()) + raster_settings = self.raster_settings + + pix_to_face, zbuf, bary_coords, dists = rasterize_meshes( + meshes_screen, + image_size=raster_settings.image_size, + blur_radius=raster_settings.blur_radius, + faces_per_pixel=raster_settings.faces_per_pixel, + bin_size=raster_settings.bin_size, + max_faces_per_bin=raster_settings.max_faces_per_bin, + perspective_correct=raster_settings.perspective_correct, + ) + + vismask = (pix_to_face > -1).float() + D = attributes.shape[-1] + attributes = attributes.clone() + attributes = attributes.view(attributes.shape[0] * attributes.shape[1], 3, attributes.shape[-1]) + N, H, W, K, _ = bary_coords.shape + mask = pix_to_face == -1 # [] + pix_to_face = pix_to_face.clone() + pix_to_face[mask] = 0 + idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D) + pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D) + pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2) + pixel_vals[mask] = 0 # Replace masked values in output. + pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2) + pixel_vals = torch.cat([pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1) + # import ipdb; ipdb.set_trace() + return pixel_vals + + +class Renderer(nn.Module): + def __init__(self, image_size, obj_filename, uv_size=256): + super(Renderer, self).__init__() + self.image_size = image_size + self.uv_size = uv_size + + verts, faces, aux = load_obj(obj_filename) + uvcoords = aux.verts_uvs[None, ...] # (N, V, 2) + uvfaces = faces.textures_idx[None, ...] # (N, F, 3) + faces = faces.verts_idx[None, ...] 
+ self.rasterizer = Pytorch3dRasterizer(image_size) + self.uv_rasterizer = Pytorch3dRasterizer(uv_size) + + # faces + self.register_buffer('faces', faces) + self.register_buffer('raw_uvcoords', uvcoords) + + # uv coordsw + uvcoords = torch.cat([uvcoords, uvcoords[:, :, 0:1] * 0. + 1.], -1) # [bz, ntv, 3] + uvcoords = uvcoords * 2 - 1 + uvcoords[..., 1] = -uvcoords[..., 1] + face_uvcoords = util.face_vertices(uvcoords, uvfaces) + self.register_buffer('uvcoords', uvcoords) + self.register_buffer('uvfaces', uvfaces) + self.register_buffer('face_uvcoords', face_uvcoords) + + # shape colors + colors = torch.tensor([74, 120, 168])[None, None, :].repeat(1, faces.max() + 1, 1).float() / 255. + face_colors = util.face_vertices(colors, faces) + self.register_buffer('face_colors', face_colors) + + ## lighting + pi = np.pi + constant_factor = torch.tensor( + [1 / np.sqrt(4 * pi), ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))), ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))), \ + ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))), (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))), + (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))), \ + (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))), (pi / 4) * (3 / 2) * (np.sqrt(5 / (12 * pi))), + (pi / 4) * (1 / 2) * (np.sqrt(5 / (4 * pi)))]) + self.register_buffer('constant_factor', constant_factor) + + + + def forward(self, vertices, transformed_vertices, albedos, lights=None, light_type='point'): + ''' + lihgts: + spherical homarnic: [N, 9(shcoeff), 3(rgb)] + vertices: [N, V, 3], vertices in work space, for calculating normals, then shading + transformed_vertices: [N, V, 3], range(-1, 1), projected vertices, for rendering + ''' + batch_size = vertices.shape[0] + ## rasterizer near 0 far 100. 
move mesh so minz larger than 0 + transformed_vertices[:, :, 2] = transformed_vertices[:, :, 2] + 10 + + # Attributes + face_vertices = util.face_vertices(vertices, self.faces.expand(batch_size, -1, -1)) + normals = util.vertex_normals(vertices, self.faces.expand(batch_size, -1, -1)) + face_normals = util.face_vertices(normals, self.faces.expand(batch_size, -1, -1)) + transformed_normals = util.vertex_normals(transformed_vertices, self.faces.expand(batch_size, -1, -1)) + transformed_face_normals = util.face_vertices(transformed_normals, self.faces.expand(batch_size, -1, -1)) + + # render + attributes = torch.cat([self.face_uvcoords.expand(batch_size, -1, -1, -1), transformed_face_normals.detach(), + face_vertices.detach(), face_normals.detach()], -1) + # import ipdb;ipdb.set_trace() + rendering = self.rasterizer(transformed_vertices, self.faces.expand(batch_size, -1, -1), attributes) + + alpha_images = rendering[:, -1, :, :][:, None, :, :].detach() + + # albedo + uvcoords_images = rendering[:, :3, :, :] + grid = (uvcoords_images).permute(0, 2, 3, 1)[:, :, :, :2] + + albedo_images = F.grid_sample(albedos, grid, align_corners=False) + + # remove inner mouth region + transformed_normal_map = rendering[:, 3:6, :, :].detach() + pos_mask = (transformed_normal_map[:, 2:, :, :] < -0.05).float() + + # shading + if lights is not None: + normal_images = rendering[:, 9:12, :, :].detach() + if lights.shape[1] == 9: + shading_images = self.add_SHlight(normal_images, lights) + else: + if light_type == 'point': + vertice_images = rendering[:, 6:9, :, :].detach() + shading = self.add_pointlight(vertice_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3]), + normal_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3]), + lights) + shading_images = shading.reshape( + [batch_size, lights.shape[1], albedo_images.shape[2], albedo_images.shape[3], 3]).permute(0, 1, + 4, 2, + 3) + shading_images = shading_images.mean(1) + else: + shading = 
self.add_directionlight(normal_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3]), + lights) + shading_images = shading.reshape( + [batch_size, lights.shape[1], albedo_images.shape[2], albedo_images.shape[3], 3]).permute(0, 1, + 4, 2, + 3) + shading_images = shading_images.mean(1) + images = albedo_images * shading_images + else: + images = albedo_images + shading_images = images.detach() * 0. + + outputs = { + 'images': images * alpha_images, + 'albedo_images': albedo_images, + 'alpha_images': alpha_images, + 'pos_mask': pos_mask, + 'shading_images': shading_images, + 'grid': grid, + 'normals': normals + } + + return outputs + + def add_SHlight(self, normal_images, sh_coeff): + ''' + sh_coeff: [bz, 9, 3] + ''' + N = normal_images + sh = torch.stack([ + N[:, 0] * 0. + 1., N[:, 0], N[:, 1], \ + N[:, 2], N[:, 0] * N[:, 1], N[:, 0] * N[:, 2], + N[:, 1] * N[:, 2], N[:, 0] ** 2 - N[:, 1] ** 2, 3 * (N[:, 2] ** 2) - 1 + ], + 1) # [bz, 9, h, w] + sh = sh * self.constant_factor[None, :, None, None] + # import ipdb; ipdb.set_trace() + shading = torch.sum(sh_coeff[:, :, :, None, None] * sh[:, :, None, :, :], 1) # [bz, 9, 3, h, w] + return shading + + def add_pointlight(self, vertices, normals, lights): + ''' + vertices: [bz, nv, 3] + lights: [bz, nlight, 6] + returns: + shading: [bz, nv, 3] + ''' + light_positions = lights[:,:,:3]; light_intensities = lights[:,:,3:] + directions_to_lights = F.normalize(light_positions[:,:,None,:] - vertices[:,None,:,:], dim=3) + # normals_dot_lights = torch.clamp((normals[:,None,:,:]*directions_to_lights).sum(dim=3), 0., 1.) 
+ normals_dot_lights = (normals[:,None,:,:]*directions_to_lights).sum(dim=3) + shading = normals_dot_lights[:,:,:,None]*light_intensities[:,:,None,:] + return shading.mean(1) + + def add_directionlight(self, normals, lights): + ''' + normals: [bz, nv, 3] + lights: [bz, nlight, 6] + returns: + shading: [bz, nlgiht, nv, 3] + ''' + light_direction = lights[:, :, :3]; + light_intensities = lights[:, :, 3:] + directions_to_lights = F.normalize(light_direction[:, :, None, :].expand(-1, -1, normals.shape[1], -1), dim=3) + normals_dot_lights = (normals[:,None,:,:]*directions_to_lights).sum(dim=3) + shading = normals_dot_lights[:, :, :, None] * light_intensities[:, :, None, :] + return shading + + def render_shape(self, vertices, transformed_vertices, images=None, lights=None): + batch_size = vertices.shape[0] + if lights is None: + light_positions = torch.tensor([[-0.1, -0.1, 0.2], + [0, 0, 1]] + )[None, :, :].expand(batch_size, -1, -1).float() + light_intensities = torch.ones_like(light_positions).float() + lights = torch.cat((light_positions, light_intensities), 2).to(vertices.device) + + ## rasterizer near 0 far 100. 
move mesh so minz larger than 0 + transformed_vertices[:, :, 2] = transformed_vertices[:, :, 2] + 10 + + # Attributes + face_vertices = util.face_vertices(vertices, self.faces.expand(batch_size, -1, -1)) + normals = util.vertex_normals(vertices, self.faces.expand(batch_size, -1, -1)); + face_normals = util.face_vertices(normals, self.faces.expand(batch_size, -1, -1)) + transformed_normals = util.vertex_normals(transformed_vertices, self.faces.expand(batch_size, -1, -1)); + transformed_face_normals = util.face_vertices(transformed_normals, self.faces.expand(batch_size, -1, -1)) + # render + attributes = torch.cat( + [self.face_colors.expand(batch_size, -1, -1, -1), transformed_face_normals.detach(), face_vertices.detach(), + face_normals.detach()], -1) + rendering = self.rasterizer(transformed_vertices, self.faces.expand(batch_size, -1, -1), attributes) + # albedo + albedo_images = rendering[:, :3, :, :] + # shading + normal_images = rendering[:, 9:12, :, :].detach() + if lights.shape[1] == 9: + shading_images = self.add_SHlight(normal_images, lights) + else: + print('directional') + shading = self.add_directionlight(normal_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3]), lights) + + shading_images = shading.reshape( + [batch_size, lights.shape[1], albedo_images.shape[2], albedo_images.shape[3], 3]).permute(0, 1, 4, 2, 3) + shading_images = shading_images.mean(1) + images = albedo_images * shading_images + + return images + + def render_normal(self, transformed_vertices, normals): + ''' + -- rendering normal + ''' + batch_size = normals.shape[0] + + # Attributes + attributes = util.face_vertices(normals, self.faces.expand(batch_size, -1, -1)) + # rasterize + rendering = self.rasterizer(transformed_vertices, self.faces.expand(batch_size, -1, -1), attributes) + + #### + alpha_images = rendering[:, -1, :, :][:, None, :, :].detach() + normal_images = rendering[:, :3, :, :] + return normal_images + + def world2uv(self, vertices): + ''' + sample vertices from 
world space to uv space + uv_vertices: [bz, 3, h, w] + ''' + batch_size = vertices.shape[0] + face_vertices = util.face_vertices(vertices, self.faces.expand(batch_size, -1, -1)).clone().detach() + uv_vertices = self.uv_rasterizer(self.uvcoords.expand(batch_size, -1, -1), + self.uvfaces.expand(batch_size, -1, -1), face_vertices)[:, :3] + + return uv_vertices + + def save_obj(self, filename, vertices, textures): + ''' + vertices: [nv, 3], tensor + texture: [3, h, w], tensor + ''' + util.save_obj(filename, vertices, self.faces[0], textures=textures, uvcoords=self.raw_uvcoords[0], + uvfaces=self.uvfaces[0]) \ No newline at end of file diff --git a/util.py b/utils/util.py similarity index 77% rename from util.py rename to utils/util.py index 50f9c3b4..b8ace579 100644 --- a/util.py +++ b/utils/util.py @@ -1,313 +1,389 @@ -import numpy as np -import torch -import torch.nn.functional as F -import math -from collections import OrderedDict -import os -from scipy.ndimage import morphology -from skimage.io import imsave -import cv2 - - -def dict2obj(d): - if isinstance(d, list): - d = [dict2obj(x) for x in d] - if not isinstance(d, dict): - return d - - class C(object): - pass - - o = C() - for k in d: - o.__dict__[k] = dict2obj(d[k]) - return o - - -def check_mkdir(path): - if not os.path.exists(path): - print('making %s' % path) - os.makedirs(path) - - -def l2_distance(verts1, verts2): - return torch.sqrt(((verts1 - verts2) ** 2).sum(2)).mean(1).mean() - - -def quat2mat(quat): - """Convert quaternion coefficients to rotation matrix. 
- Args: - quat: size = [B, 4] 4 <===>(w, x, y, z) - Returns: - Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] - """ - norm_quat = quat - norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True) - w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3] - - B = quat.size(0) - - w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) - wx, wy, wz = w * x, w * y, w * z - xy, xz, yz = x * y, x * z, y * z - - rotMat = torch.stack([w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, - 2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx, - 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3) - return rotMat - - -def batch_rodrigues(theta): - # theta N x 3 - batch_size = theta.shape[0] - l1norm = torch.norm(theta + 1e-8, p=2, dim=1) - angle = torch.unsqueeze(l1norm, -1) - normalized = torch.div(theta, angle) - angle = angle * 0.5 - v_cos = torch.cos(angle) - v_sin = torch.sin(angle) - quat = torch.cat([v_cos, v_sin * normalized], dim=1) - - return quat2mat(quat) - - -def batch_orth_proj(X, camera): - ''' - X is N x num_points x 3 - ''' - camera = camera.clone().view(-1, 1, 3) - X_trans = X[:, :, :2] + camera[:, :, 1:] - X_trans = torch.cat([X_trans, X[:, :, 2:]], 2) - shape = X_trans.shape - # Xn = (camera[:, :, 0] * X_trans.view(shape[0], -1)).view(shape) - Xn = (camera[:, :, 0:1] * X_trans) - return Xn - - -def batch_persp_proj(vertices, cam, f, t, orig_size=256, eps=1e-9): - ''' - Calculate projective transformation of vertices given a projection matrix - Input parameters: - f: torch tensor of focal length - t: batch_size * 1 * 3 xyz translation in world coordinate - K: batch_size * 3 * 3 intrinsic camera matrix - R, t: batch_size * 3 * 3, batch_size * 1 * 3 extrinsic calibration parameters - dist_coeffs: vector of distortion coefficients - orig_size: original size of image captured by the camera - Returns: For each point [X,Y,Z] in world coordinates [u,v,z] where u,v are the coordinates of the 
projection in - pixels and z is the depth - ''' - device = vertices.device - - K = torch.tensor([f, 0., cam['c'][0], 0., f, cam['c'][1], 0., 0., 1.]).view(3, 3)[None, ...].repeat( - vertices.shape[0], 1).to(device) - R = batch_rodrigues(cam['r'][None, ...].repeat(vertices.shape[0], 1)).to(device) - dist_coeffs = cam['k'][None, ...].repeat(vertices.shape[0], 1).to(device) - - vertices = torch.matmul(vertices, R.transpose(2, 1)) + t - x, y, z = vertices[:, :, 0], vertices[:, :, 1], vertices[:, :, 2] - x_ = x / (z + eps) - y_ = y / (z + eps) - - # Get distortion coefficients from vector - k1 = dist_coeffs[:, None, 0] - k2 = dist_coeffs[:, None, 1] - p1 = dist_coeffs[:, None, 2] - p2 = dist_coeffs[:, None, 3] - k3 = dist_coeffs[:, None, 4] - - # we use x_ for x' and x__ for x'' etc. - r = torch.sqrt(x_ ** 2 + y_ ** 2) - x__ = x_ * (1 + k1 * (r ** 2) + k2 * (r ** 4) + k3 * (r ** 6)) + 2 * p1 * x_ * y_ + p2 * (r ** 2 + 2 * x_ ** 2) - y__ = y_ * (1 + k1 * (r ** 2) + k2 * (r ** 4) + k3 * (r ** 6)) + p1 * (r ** 2 + 2 * y_ ** 2) + 2 * p2 * x_ * y_ - vertices = torch.stack([x__, y__, torch.ones_like(z)], dim=-1) - vertices = torch.matmul(vertices, K.transpose(1, 2)) - u, v = vertices[:, :, 0], vertices[:, :, 1] - v = orig_size - v - # map u,v from [0, img_size] to [-1, 1] to be compatible with the renderer - u = 2 * (u - orig_size / 2.) / orig_size - v = 2 * (v - orig_size / 2.) 
/ orig_size - vertices = torch.stack([u, v, z], dim=-1) - - return vertices - - -def face_vertices(vertices, faces): - """ - :param vertices: [batch size, number of vertices, 3] - :param faces: [batch size, number of faces, 3] - :return: [batch size, number of faces, 3, 3] - """ - assert (vertices.ndimension() == 3) - assert (faces.ndimension() == 3) - assert (vertices.shape[0] == faces.shape[0]) - assert (vertices.shape[2] == 3) - assert (faces.shape[2] == 3) - - bs, nv = vertices.shape[:2] - bs, nf = faces.shape[:2] - device = vertices.device - faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None] - vertices = vertices.reshape((bs * nv, 3)) - # pytorch only supports long and byte tensors for indexing - return vertices[faces.long()] - - -def vertex_normals(vertices, faces): - """ - :param vertices: [batch size, number of vertices, 3] - :param faces: [batch size, number of faces, 3] - :return: [batch size, number of vertices, 3] - """ - assert (vertices.ndimension() == 3) - assert (faces.ndimension() == 3) - assert (vertices.shape[0] == faces.shape[0]) - assert (vertices.shape[2] == 3) - assert (faces.shape[2] == 3) - - bs, nv = vertices.shape[:2] - bs, nf = faces.shape[:2] - device = vertices.device - normals = torch.zeros(bs * nv, 3).to(device) - - faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None] # expanded faces - vertices_faces = vertices.reshape((bs * nv, 3))[faces.long()] - - faces = faces.view(-1, 3) - vertices_faces = vertices_faces.view(-1, 3, 3) - - normals.index_add_(0, faces[:, 1].long(), - torch.cross(vertices_faces[:, 2] - vertices_faces[:, 1], vertices_faces[:, 0] - vertices_faces[:, 1])) - normals.index_add_(0, faces[:, 2].long(), - torch.cross(vertices_faces[:, 0] - vertices_faces[:, 2], vertices_faces[:, 1] - vertices_faces[:, 2])) - normals.index_add_(0, faces[:, 0].long(), - torch.cross(vertices_faces[:, 1] - vertices_faces[:, 0], vertices_faces[:, 2] - vertices_faces[:, 0])) - - 
normals = F.normalize(normals, eps=1e-6, dim=1) - normals = normals.reshape((bs, nv, 3)) - # pytorch only supports long and byte tensors for indexing - return normals - - -def tensor_vis_landmarks(images, landmarks, gt_landmarks=None, color='g', isScale=True): - # visualize landmarks - vis_landmarks = [] - images = images.cpu().numpy() - predicted_landmarks = landmarks.detach().cpu().numpy() - if gt_landmarks is not None: - gt_landmarks_np = gt_landmarks.detach().cpu().numpy() - for i in range(images.shape[0]): - image = images[i] - image = image.transpose(1, 2, 0)[:, :, [2, 1, 0]].copy(); - image = (image * 255) - if isScale: - predicted_landmark = predicted_landmarks[i] * image.shape[0] / 2 + image.shape[0] / 2 - else: - predicted_landmark = predicted_landmarks[i] - - if predicted_landmark.shape[0] == 68: - image_landmarks = plot_kpts(image, predicted_landmark, color) - if gt_landmarks is not None: - image_landmarks = plot_verts(image_landmarks, - gt_landmarks_np[i] * image.shape[0] / 2 + image.shape[0] / 2, 'r') - else: - image_landmarks = plot_verts(image, predicted_landmark, color) - if gt_landmarks is not None: - image_landmarks = plot_verts(image_landmarks, - gt_landmarks_np[i] * image.shape[0] / 2 + image.shape[0] / 2, 'r') - - vis_landmarks.append(image_landmarks) - - vis_landmarks = np.stack(vis_landmarks) - vis_landmarks = torch.from_numpy( - vis_landmarks[:, :, :, [2, 1, 0]].transpose(0, 3, 1, 2)) / 255. # , dtype=torch.float32) - return vis_landmarks - - -end_list = np.array([17, 22, 27, 42, 48, 31, 36, 68], dtype = np.int32) - 1 -def plot_kpts(image, kpts, color = 'r'): - ''' Draw 68 key points - Args: - image: the input image - kpt: (68, 3). 
- ''' - if color == 'r': - c = (255, 0, 0) - elif color == 'g': - c = (0, 255, 0) - elif color == 'b': - c = (255, 0, 0) - image = image.copy() - kpts = kpts.copy() - - for i in range(kpts.shape[0]): - st = kpts[i, :2] - if kpts.shape[1]==4: - if kpts[i, 3] > 0.5: - c = (0, 255, 0) - else: - c = (0, 0, 255) - image = cv2.circle(image,(st[0], st[1]), 1, c, 2) - if i in end_list: - continue - ed = kpts[i + 1, :2] - image = cv2.line(image, (st[0], st[1]), (ed[0], ed[1]), (255, 255, 255), 1) - - return image - - -def save_obj(filename, vertices, faces, textures=None, uvcoords=None, uvfaces=None, texture_type='surface'): - assert vertices.ndimension() == 2 - assert faces.ndimension() == 2 - assert texture_type in ['surface', 'vertex'] - # assert texture_res >= 2 - - if textures is not None and texture_type == 'surface': - textures =textures.detach().cpu().numpy().transpose(1,2,0) - filename_mtl = filename[:-4] + '.mtl' - filename_texture = filename[:-4] + '.png' - material_name = 'material_1' - # texture_image, vertices_textures = create_texture_image(textures, texture_res) - texture_image = textures - texture_image = texture_image.clip(0, 1) - texture_image = (texture_image * 255).astype('uint8') - imsave(filename_texture, texture_image) - - faces = faces.detach().cpu().numpy() - - with open(filename, 'w') as f: - f.write('# %s\n' % os.path.basename(filename)) - f.write('#\n') - f.write('\n') - - if textures is not None: - f.write('mtllib %s\n\n' % os.path.basename(filename_mtl)) - - if textures is not None and texture_type == 'vertex': - for vertex, color in zip(vertices, textures): - f.write('v %.8f %.8f %.8f %.8f %.8f %.8f\n' % (vertex[0], vertex[1], vertex[2], - color[0], color[1], color[2])) - f.write('\n') - else: - for vertex in vertices: - f.write('v %.8f %.8f %.8f\n' % (vertex[0], vertex[1], vertex[2])) - f.write('\n') - - if textures is not None and texture_type == 'surface': - for vertex in uvcoords.reshape((-1, 2)): - f.write('vt %.8f %.8f\n' % (vertex[0], 
vertex[1])) - f.write('\n') - - f.write('usemtl %s\n' % material_name) - for i, face in enumerate(faces): - f.write('f %d/%d %d/%d %d/%d\n' % ( - face[0] + 1, uvfaces[i,0]+1, face[1] + 1, uvfaces[i,1]+1, face[2] + 1, uvfaces[i,2]+1)) - f.write('\n') - else: - for face in faces: - f.write('f %d %d %d\n' % (face[0] + 1, face[1] + 1, face[2] + 1)) - - if textures is not None and texture_type == 'surface': - with open(filename_mtl, 'w') as f: - f.write('newmtl %s\n' % material_name) - f.write('map_Kd %s\n' % os.path.basename(filename_texture)) \ No newline at end of file +import numpy as np +import torch +import torch.nn.functional as F +import os +from skimage.io import imsave +import cv2 +from PIL import Image +import matplotlib.pyplot as plt +import torchvision.transforms as transforms + + +def dict2obj(d): + if isinstance(d, list): + d = [dict2obj(x) for x in d] + if not isinstance(d, dict): + return d + + class C(object): + pass + + o = C() + for k in d: + o.__dict__[k] = dict2obj(d[k]) + return o + + +def check_mkdir(path): + if not os.path.exists(path): + print('making %s' % path) + os.makedirs(path) + + +def l2_distance(verts1, verts2): + return torch.sqrt(((verts1 - verts2) ** 2).sum(2)).mean(1).mean() + + +def quat2mat(quat): + """Convert quaternion coefficients to rotation matrix. 
def quat2mat(quat):
    """Convert a batch of quaternions to rotation matrices.

    The quaternions are L2-normalised first, so they need not be unit length.

    Args:
        quat: size = [B, 4] 4 <===>(w, x, y, z)
    Returns:
        Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
    """
    norm_quat = quat
    norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
    w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3]

    B = quat.size(0)

    w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
    wx, wy, wz = w * x, w * y, w * z
    xy, xz, yz = x * y, x * z, y * z

    rotMat = torch.stack([w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz,
                          2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx,
                          2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
    return rotMat


def batch_rodrigues(theta):
    """Convert axis-angle vectors to rotation matrices (via quaternions).

    Args:
        theta: size = [N, 3]; the direction is the rotation axis and the
            L2 norm is the rotation angle in radians.
    Returns:
        Rotation matrices -- size = [N, 3, 3]
    """
    # 1e-8 is added element-wise before taking the norm so that an all-zero
    # rotation does not divide by zero when the axis is normalised below.
    angle = torch.norm(theta + 1e-8, p=2, dim=1).unsqueeze(-1)
    axis = torch.div(theta, angle)
    half_angle = angle * 0.5
    quat = torch.cat([torch.cos(half_angle), torch.sin(half_angle) * axis], dim=1)

    return quat2mat(quat)


def batch_orth_proj(X, camera):
    '''
    Scaled orthographic projection.

    X is N x num_points x 3
    camera is viewed as N x 1 x 3: column 0 is a scale, columns 1: are an
    (x, y) translation. The translation is added to x/y first, then x, y AND z
    are multiplied by the scale.

    Returns a tensor with the same shape as X.
    '''
    camera = camera.clone().view(-1, 1, 3)
    xy_shifted = X[:, :, :2] + camera[:, :, 1:]
    X_trans = torch.cat([xy_shifted, X[:, :, 2:]], 2)
    return camera[:, :, 0:1] * X_trans


def batch_persp_proj(vertices, cam, f, t, orig_size=256, eps=1e-9):
    '''
    Perspective projection of vertices with lens distortion.

    Input parameters:
    vertices: batch_size * num_points * 3 points in world coordinates
    cam: mapping with entries
        'c': principal point (cx, cy) in pixels
        'r': axis-angle rotation, 3 values (shared by the whole batch)
        'k': 5 distortion coefficients ordered (k1, k2, p1, p2, k3)
    f: focal length (a number or a one-element tensor)
    t: batch_size * 1 * 3 xyz translation in world coordinate
    orig_size: original size of image captured by the camera
    eps: added to the depth before dividing, to avoid division by zero

    Returns: For each point [X,Y,Z] in world coordinates [u,v,z] where u,v are
    the projected coordinates mapped from [0, orig_size] to [-1, 1] (v is
    flipped first) and z is the depth after the extrinsic transform.
    '''
    device = vertices.device
    batch = vertices.shape[0]

    fx = float(f)
    cx, cy = float(cam['c'][0]), float(cam['c'][1])
    # The intrinsic matrix is (1, 3, 3) before batching. The previous
    # `.repeat(batch, 1)` passed fewer repeat dims than tensor dims and raised
    # a RuntimeError on every call; expand to (batch, 3, 3) instead.
    K = torch.tensor([fx, 0., cx, 0., fx, cy, 0., 0., 1.],
                     dtype=vertices.dtype, device=device).view(1, 3, 3).expand(batch, 3, 3)
    R = batch_rodrigues(cam['r'][None, ...].repeat(batch, 1)).to(device)
    dist_coeffs = cam['k'][None, ...].repeat(batch, 1).to(device)

    vertices = torch.matmul(vertices, R.transpose(2, 1)) + t
    x, y, z = vertices[:, :, 0], vertices[:, :, 1], vertices[:, :, 2]
    x_ = x / (z + eps)
    y_ = y / (z + eps)

    # Get distortion coefficients from vector
    k1 = dist_coeffs[:, None, 0]
    k2 = dist_coeffs[:, None, 1]
    p1 = dist_coeffs[:, None, 2]
    p2 = dist_coeffs[:, None, 3]
    k3 = dist_coeffs[:, None, 4]

    # we use x_ for x' and x__ for x'' etc.
    r = torch.sqrt(x_ ** 2 + y_ ** 2)
    radial = 1 + k1 * (r ** 2) + k2 * (r ** 4) + k3 * (r ** 6)
    x__ = x_ * radial + 2 * p1 * x_ * y_ + p2 * (r ** 2 + 2 * x_ ** 2)
    y__ = y_ * radial + p1 * (r ** 2 + 2 * y_ ** 2) + 2 * p2 * x_ * y_
    vertices = torch.stack([x__, y__, torch.ones_like(z)], dim=-1)
    vertices = torch.matmul(vertices, K.transpose(1, 2))
    u, v = vertices[:, :, 0], vertices[:, :, 1]
    v = orig_size - v
    # map u,v from [0, img_size] to [-1, 1] to be compatible with the renderer
    u = 2 * (u - orig_size / 2.) / orig_size
    v = 2 * (v - orig_size / 2.) / orig_size
    vertices = torch.stack([u, v, z], dim=-1)

    return vertices


def face_vertices(vertices, faces):
    """
    Gather the three corner positions of every face.

    :param vertices: [batch size, number of vertices, 3]
    :param faces: [batch size, number of faces, 3]
    :return: [batch size, number of faces, 3, 3]
    """
    assert (vertices.ndimension() == 3)
    assert (faces.ndimension() == 3)
    assert (vertices.shape[0] == faces.shape[0])
    assert (vertices.shape[2] == 3)
    assert (faces.shape[2] == 3)

    bs, nv = vertices.shape[:2]
    device = vertices.device
    # Offset each batch item's indices so they address the flattened
    # (bs * nv, 3) vertex array.
    batch_offset = torch.arange(bs, dtype=torch.int32, device=device) * nv
    faces = faces + batch_offset[:, None, None]
    vertices = vertices.reshape((bs * nv, 3))
    # pytorch only supports long and byte tensors for indexing
    return vertices[faces.long()]


def vertex_normals(vertices, faces):
    """
    Per-vertex unit normals, accumulated from the (unnormalised) cross
    products of every face that touches the vertex.

    :param vertices: [batch size, number of vertices, 3]
    :param faces: [batch size, number of faces, 3]
    :return: [batch size, number of vertices, 3]
    """
    assert (vertices.ndimension() == 3)
    assert (faces.ndimension() == 3)
    assert (vertices.shape[0] == faces.shape[0])
    assert (vertices.shape[2] == 3)
    assert (faces.shape[2] == 3)

    bs, nv = vertices.shape[:2]
    device = vertices.device
    # Match the input dtype; index_add_ requires source and target to agree.
    normals = torch.zeros(bs * nv, 3, dtype=vertices.dtype, device=device)

    batch_offset = torch.arange(bs, dtype=torch.int32, device=device) * nv
    faces = faces + batch_offset[:, None, None]  # expanded faces
    vertices_faces = vertices.reshape((bs * nv, 3))[faces.long()]

    faces = faces.view(-1, 3)
    vertices_faces = vertices_faces.view(-1, 3, 3)
    v0, v1, v2 = vertices_faces[:, 0], vertices_faces[:, 1], vertices_faces[:, 2]

    # dim=1 is passed explicitly: without it torch.cross uses the first axis of
    # size 3, which is the face axis whenever there are exactly three faces.
    normals.index_add_(0, faces[:, 1].long(), torch.cross(v2 - v1, v0 - v1, dim=1))
    normals.index_add_(0, faces[:, 2].long(), torch.cross(v0 - v2, v1 - v2, dim=1))
    normals.index_add_(0, faces[:, 0].long(), torch.cross(v1 - v0, v2 - v0, dim=1))

    normals = F.normalize(normals, eps=1e-6, dim=1)
    normals = normals.reshape((bs, nv, 3))
    return normals


def tensor_vis_landmarks(images, landmarks, gt_landmarks=None, color='g', isScale=True):
    """Draw landmarks on a batch of images and return the annotated batch.

    :param images: tensor indexed [batch, channel, height, width]; values are
        multiplied by 255 before drawing (presumably in [0, 1] -- confirm)
    :param landmarks: predicted landmarks, one set per image
    :param gt_landmarks: optional ground-truth landmarks, drawn in 'r'; these
        are always rescaled from [-1, 1], regardless of isScale
    :param color: colour key used for the predicted landmarks
    :param isScale: if True, predicted landmarks are mapped from [-1, 1] to
        pixels using the image height
    :return: tensor indexed [batch, channel, height, width], divided by 255
    """
    vis_landmarks = []
    images = images.cpu().numpy()
    predicted_landmarks = landmarks.detach().cpu().numpy()
    if gt_landmarks is not None:
        gt_landmarks_np = gt_landmarks.detach().cpu().numpy()
    for i in range(images.shape[0]):
        # CHW -> HWC with the channel order reversed for drawing.
        image = images[i].transpose(1, 2, 0)[:, :, [2, 1, 0]].copy()
        image = (image * 255)
        if isScale:
            predicted_landmark = predicted_landmarks[i] * image.shape[0] / 2 + image.shape[0] / 2
        else:
            predicted_landmark = predicted_landmarks[i]

        # 68-point sets get the connected face outline, anything else is
        # drawn as plain points by plot_verts (defined outside this section).
        if predicted_landmark.shape[0] == 68:
            image_landmarks = plot_kpts(image, predicted_landmark, color)
        else:
            image_landmarks = plot_verts(image, predicted_landmark, color)
        if gt_landmarks is not None:
            image_landmarks = plot_verts(image_landmarks,
                                         gt_landmarks_np[i] * image.shape[0] / 2 + image.shape[0] / 2, 'r')

        vis_landmarks.append(image_landmarks)

    vis_landmarks = np.stack(vis_landmarks)
    # Undo the channel reversal and go back to [batch, channel, height, width].
    vis_landmarks = torch.from_numpy(
        vis_landmarks[:, :, :, [2, 1, 0]].transpose(0, 3, 1, 2)) / 255.
    return vis_landmarks


# Indices (0-based) of landmarks that are NOT connected to their successor
# when plot_kpts draws the outline.
end_list = np.array([17, 22, 27, 42, 48, 31, 36, 68], dtype=np.int32) - 1


def plot_kpts(image, kpts, color='r'):
    ''' Draw 68 key points
    Args:
        image: the input image (a copy is drawn on and returned)
        kpts: (68, 2+) array; only the first two columns are used as x, y.
            With exactly 4 columns the 4th selects the point colour
            (> 0.5 vs. not) -- presumably a visibility flag, confirm.
        color: 'r', 'g' or 'b'; any other value falls back to 'r'.
    Returns:
        The annotated copy of the image.
    '''
    palette = {
        'r': (255, 0, 0),
        'g': (0, 255, 0),
        # Was (255, 0, 0), i.e. identical to 'r'.
        'b': (0, 0, 255),
    }
    # An unknown key used to leave `c` unbound and crash at the first circle.
    c = palette.get(color, palette['r'])
    image = image.copy()
    kpts = kpts.copy()
    num_points = kpts.shape[0]

    for i in range(num_points):
        st = kpts[i, :2]
        if kpts.shape[1] == 4:
            if kpts[i, 3] > 0.5:
                c = (0, 255, 0)
            else:
                c = (0, 0, 255)
        image = cv2.circle(image, (int(st[0]), int(st[1])), 1, c, 2)
        # Skip the connecting line at segment ends and at the last point, so
        # inputs with a different number of points cannot index past the end.
        if i in end_list or i + 1 >= num_points:
            continue
        ed = kpts[i + 1, :2]
        image = cv2.line(image, (int(st[0]), int(st[1])), (int(ed[0]), int(ed[1])), (255, 255, 255), 1)

    return image


def save_obj(filename, vertices, faces, textures=None, uvcoords=None, uvfaces=None, texture_type='surface'):
    """Write a mesh to a Wavefront OBJ file.

    :param filename: output path of the .obj file
    :param vertices: tensor [number of vertices, 3]
    :param faces: tensor [number of faces, 3], 0-based vertex indices
    :param textures: None, or
        - 'surface': image tensor indexed [channel, height, width]; it is
          clipped to [0, 1], saved as <name>.png and referenced from <name>.mtl
        - 'vertex': one colour triple per vertex, appended to the `v` lines
    :param uvcoords: UV coordinates, reshaped to (-1, 2); needed for 'surface'
    :param uvfaces: per-face UV indices (0-based); needed for 'surface'
    :param texture_type: 'surface' or 'vertex'
    """
    assert vertices.ndimension() == 2
    assert faces.ndimension() == 2
    assert texture_type in ['surface', 'vertex']

    surface_textured = textures is not None and texture_type == 'surface'
    vertex_colored = textures is not None and texture_type == 'vertex'

    if surface_textured:
        textures = textures.detach().cpu().numpy().transpose(1, 2, 0)
        # splitext instead of filename[:-4], so the sidecar names are right
        # even when the extension is not exactly three characters long.
        stem = os.path.splitext(filename)[0]
        filename_mtl = stem + '.mtl'
        filename_texture = stem + '.png'
        material_name = 'material_1'
        texture_image = textures.clip(0, 1)
        texture_image = (texture_image * 255).astype('uint8')
        imsave(filename_texture, texture_image)
    elif vertex_colored and torch.is_tensor(textures):
        textures = textures.detach().cpu().numpy()

    vertices = vertices.detach().cpu().numpy()
    faces = faces.detach().cpu().numpy()

    with open(filename, 'w') as f:
        f.write('# %s\n' % os.path.basename(filename))
        f.write('#\n')
        f.write('\n')

        # Only a surface texture has a material file. This used to be written
        # for vertex colours too and crashed on the undefined filename_mtl.
        if surface_textured:
            f.write('mtllib %s\n\n' % os.path.basename(filename_mtl))

        if vertex_colored:
            for vertex, color in zip(vertices, textures):
                f.write('v %.8f %.8f %.8f %.8f %.8f %.8f\n' % (vertex[0], vertex[1], vertex[2],
                                                               color[0], color[1], color[2]))
            f.write('\n')
        else:
            for vertex in vertices:
                f.write('v %.8f %.8f %.8f\n' % (vertex[0], vertex[1], vertex[2]))
            f.write('\n')

        if surface_textured:
            for vertex in uvcoords.reshape((-1, 2)):
                f.write('vt %.8f %.8f\n' % (vertex[0], vertex[1]))
            f.write('\n')

            f.write('usemtl %s\n' % material_name)
            # OBJ indices are 1-based.
            for i, face in enumerate(faces):
                f.write('f %d/%d %d/%d %d/%d\n' % (
                    face[0] + 1, uvfaces[i, 0] + 1, face[1] + 1, uvfaces[i, 1] + 1, face[2] + 1, uvfaces[i, 2] + 1))
            f.write('\n')
        else:
            for face in faces:
                f.write('f %d %d %d\n' % (face[0] + 1, face[1] + 1, face[2] + 1))

    if surface_textured:
        with open(filename_mtl, 'w') as f:
            f.write('newmtl %s\n' % material_name)
            f.write('map_Kd %s\n' % os.path.basename(filename_texture))


def face_seg(img, net, cropped_size):
    """Run a face-parsing network and return a binary face mask.

    :param img: BGR image array (it is converted with cv2.COLOR_BGR2RGB)
    :param net: segmentation network; `net(x)[0]` must be the per-class scores
        with a leading batch dimension
    :param cropped_size: side length of the returned square mask
    :return: float32 array [1, 1, cropped_size, cropped_size] with values 0/1
    """
    # Label ids counted as "face" -- presumably the skin/brow/eye/nose/mouth
    # classes of the parsing model; verify against the model's label map.
    face_area = [1, 2, 3, 4, 5, 6, 10, 11, 12, 13]
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    pil_image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    resize_pil_image = pil_image.resize((512, 512), Image.BILINEAR)
    tensor_image = torch.unsqueeze(to_tensor(resize_pil_image), 0)

    # Run on whatever device the network lives on; fall back to the previous
    # hard-coded CUDA when that cannot be determined.
    try:
        device = next(net.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cuda')
    tensor_image = tensor_image.to(device)

    with torch.no_grad():
        out = net(tensor_image)[0]
    parsing = out.squeeze(0).cpu().detach().numpy().argmax(0)
    labels = parsing.astype(np.uint8)
    face_mask = np.isin(labels, face_area).astype(np.float64)

    image_mask = cv2.resize(face_mask, (cropped_size, cropped_size))
    image_mask = image_mask[..., None].astype('float32')
    image_mask = image_mask.transpose(2, 0, 1)
    # Resizing interpolates, so re-binarise: anything non-zero counts as face.
    image_mask_bn = (image_mask != 0).astype('float32')

    return image_mask_bn[None, :, :, :]


def crop_img(ori_image, rect, cropped_size):
    """Crop a square window around a face rectangle.

    The window side is the larger of `cropped_size` and 1.2x the longer side
    of `rect`. The window is centred on the rectangle and then shifted so it
    stays inside the image; if the image is smaller than the window, the crop
    starts at 0 and is simply as large as the image allows.

    :param ori_image: array indexed [row, column, channel]
    :param rect: (left, top, right, bottom) in pixels
    :param cropped_size: minimum side length of the crop window
    :return: (crop_image, new_rect) where new_rect is `rect` expressed in the
        coordinates of the crop
    """
    l, t, r, b = rect
    center_x = r - (r - l) // 2
    center_y = b - (b - t) // 2
    crop_size = max((r - l) * 1.2, (b - t) * 1.2)
    side = crop_size if crop_size > cropped_size else cropped_size

    crop_ly = int(max(0, center_y - side // 2))
    crop_lx = int(max(0, center_x - side // 2))
    # Keep the window inside the image. The outer max(0, ...) stops the start
    # from going negative when the image is smaller than the window, which
    # used to wrap the slice around from the far edge.
    crop_ly = int(max(0, min(ori_image.shape[0] - side, crop_ly)))
    crop_lx = int(max(0, min(ori_image.shape[1] - side, crop_lx)))
    crop_image = ori_image[crop_ly: int(crop_ly + side), crop_lx: int(crop_lx + side), :]

    new_rect = [l - crop_lx, t - crop_ly, r - crop_lx, b - crop_ly]
    return crop_image, new_rect


def resize_para(ori_frame):
    """Compute the downscaled size used as detector input.

    The longer side is scaled to 640 when it is >= 1280, otherwise to half
    its length, but never below 64.

    :param ori_frame: array whose first two dimensions are the frame size
    :return: [dim0 / scale, dim1 / scale, scale], the first two as ints.
        The order follows `ori_frame.shape`, i.e. for an image array the first
        entry comes from the row count.
    """
    # shape[:2] so that 2-D (single channel) frames work as well.
    w, h = ori_frame.shape[:2]
    d = max(w, h)
    scale_to = 640 if d >= 1280 else d / 2
    scale_to = max(64, scale_to)
    input_scale = d / scale_to
    w = int(w / input_scale)
    h = int(h / input_scale)
    image_info = [w, h, input_scale]
    return image_info


def draw_train_process(title, iters, loss, label_loss):
    """Plot a loss curve with matplotlib and show it (blocks until closed).

    :param title: figure title
    :param iters: x values (iteration numbers)
    :param loss: y values (loss per iteration)
    :param label_loss: legend label of the curve
    """
    plt.title(title, fontsize=24)
    plt.xlabel("iter", fontsize=20)
    plt.ylabel("loss", fontsize=20)
    plt.plot(iters, loss, color='red', label=label_loss)
    plt.legend()
    plt.grid()
    plt.show()