Teoría y tutorial de implementación de Pytorch para encontrar la pose de un objeto a partir de una sola imagen monocular
Uno de los problemas centrales de la visión por ordenador y la robótica es la comprensión de cómo se posicionan los objetos con respecto al robot o el entorno. En este post, explicaré la teoría subyacente y daré un tutorial de implementación de pytorch del artículo “Posición del objeto 6-DoF desde puntos clave clave semántica” por Pavlakos et al .
En este enfoque, hay dos pasos. El primer paso es predecir los “puntos clave semánticos” en la imagen 2D. En el segundo paso, estimamos la pose del objeto maximizando la consistencia geométrica entre el conjunto predicho de puntos clave semánticos y un modelo 3D del objeto utilizando un modelo de cámara en perspectiva.
Localización de punto clave
Primero, colocamos un cuadro delimitador en el objeto de interés usando un algoritmo de detección de objetos estándar, como Faster-RCNN. Luego consideramos solo la región del objeto para la localización del punto clave. El papel utiliza la arquitectura de red de “reloj de arena apilado” para este propósito. La red toma una imagen RGB y genera un conjunto de mapas de calor, un mapa de calor para cada punto clave. Los mapas de calor permiten a la red expresar su confianza en una región en lugar de retroceder una sola posición x, y para un punto clave.
Como se puede ver en la imagen anterior, la red tiene dos relojes de arena. Además, cada reloj de arena tiene una parte de disminución de resolución y una parte de mejora de muestra. El propósito del segundo reloj de arena es refinar la salida del primer reloj de arena. Las partes de submuestreo consisten en capas de convolución alternas y de agrupación máxima. Cuando la salida alcanza una resolución de 4 X 4, comienza el muestreo ascendente. Las partes de muestreo ascendente consisten en capas de convolución y muestreo ascendente (deconvolución). La red se entrena usando la pérdida de L2 en las salidas de la primera y la segunda relojes de arena.
Optimización de pose
Ahora que tenemos Los puntos clave, podemos usarlos para encontrar la pose. Todo lo que necesitamos es un modelo del objeto en el que estamos interesados. Los autores del artículo definen un modelo S deformable que se compone de una forma media B_0 agregada con una serie de variaciones B_i que se calculan con PCA. La forma se define como una matriz de 3xP donde P es el número de puntos clave.
El principal problema de optimización es la reducción del residual inferior. Aquí W es el conjunto de puntos clave 2D normalizados en coordenadas homogéneas. Z es una matriz diagonal que representa la profundidad de los puntos clave. R y T son la matriz de rotación y los vectores de traducción respectivamente. Los aspectos desconocidos sobre los que optimizamos son Z, R, T y C.
A continuación se muestra la pérdida real que queremos minimizar. Aquí, D es la confianza de los puntos clave. Hacemos esto para penalizar el error en los puntos clave sobre los que la red está más segura. El segundo término de la pérdida es un término de regularización destinado a penalizar grandes desviaciones de la forma media del modelo de interés.
¡Eso es todo! El resultado R y T definen la pose del objeto. Ahora echemos un vistazo a la implementación de pytorch.
Implementación de Pytorch
Para la implementación, seguiremos de cerca el código provisto en CIS 580 en la Universidad de Pennsylvania. He simplificado el código fusionando algunos archivos y eliminando algunos pasos de aumento de datos.
Vamos a sumergirnos en el código. Primero, necesitamos clonar un repositorio:
git clone https://github.com/vaishak2future/posefromkeypoints.git
Necesitamos descomprimir el data.zip para que el directorio de nivel superior contenga tres carpetas : datos, salida, y utils. Ahora, vamos a ejecutar el cuaderno de Jupyter. El primer bloque de código que inspeccionaremos es la clase Trainer. Esta clase carga el tren y prueba los conjuntos de datos y ejecuta algunas transformaciones para que estén en el formato que deseamos. Los datos se recortan y se rellenan para que solo veamos el cuadro delimitador alrededor del objeto de interés. Luego, las ubicaciones de los puntos clave de la verdad terrestre se expresan como mapas de calor. Luego, los datos se convierten en tensores adecuadamente formateados y se normalizan. Finalmente, el entrenador también carga el modelo de reloj de arena. Una vez que llamamos el método del tren, hemos terminado con la localización del punto clave.
clase Entrenador (objeto): def __init __ (self): self.device = torch.device (& # 039; cuda & # 039; if torch.cuda.is_available () else & # 039; cpu & # 039;) train_transform_list = [19659023] test_transform_list = [CropAndPad(out_size=(256, 256)),LocsToHeatmaps(out_size=(64, 64)),ToTensor(),Normalize()] self.train_ds = Dataset (is_train = Verdadero transform = transforms.Compose (train_transform_list)) ] ] ] ] ] transform = transforms.Compose (test_transform_list)) self.model = hg (num_stacks = 1, num_blocks = 1, num_classes = 10) .to (self.device) # define la función de pérdida y el optimizador self .heatmap_loss = torch.nn.MSELoss (). to (self.device) # para Global loss self.optimizer = torch.optim.RMSprop (self.model.parameters (), lr = 2.5e-4) self.train_data_loader = DataLoader (self.train_ds, batch_size = 8, num_workers = 8, pin_memory = True ] shuffle = True ] self.test_data_loader = DataLoader (self.test_ds, batch_size = 32, num_workers = 8, pin_memory = True shuffle = True 19659022] self.summary_iters = [] self.losses = [] self.pcks = [] def train (self): self.total_step_count = 0 start_time = time ( ) para época en el rango de (1,400 + 1): impresión ("Época % d / % d "% (época, 400)) para paso, lote en enumeración (self.train_data_loader): self.model.train () batch = {k : v.to (self.device) if isinstance (v, torch.Tensor) else v para k, v en lote .items ()} self.optimizer.zero_grad () pred_heatmap_list = self.model (batch ['image']) loss = self.heatmap_loss (pred_heatmap_list [-1]batch ['keypoint_heatmaps']) ['keypoint_heatmaps'] ] loss.backward () self.optimizer.step () self.total_step_count + = 1 checkpoint = {& # 039; model & # 039 ;: self.model.state_dict () } torch.save (checkpoint, & # 039; ./ output / model_checkpoint.pt & # 039;)
Antes de comenzar la optimización de la postura, debemos definir una costumbre Función que vamos a utilizar mucho. La función de Rodrigues nos ayudará a convertir los vectores de representación del ángulo del eje en matrices de rotación 3×3. Lo definimos de esta manera para que podamos usar la funcionalidad autograd de pytorch.
class Rodrigues (torch.autograd.Function): @staticmethod def forward ( self, inp): pose = inp.detach (). cpu (). numpy () rotm, part_jacob = cv2.Rodrigues (pose) self.jacob = torch.Tensor (np.transpose (part_jacob)). contiguo () rotation_matrix = torch.Tensor (rotm.ravel ()) return rotation_matrix.view (3,3) @staticmethod def hacia atrás (self, grad_output): grad_output = grad_output.view (1, -1) grad_input = torch.mm (grad_output, self.jacob) grad_input = grad_input.view (-1) [19659033] return grad_input rodrigues = Rodrigues.apply
Finalmente, escribimos la función de optimización de la postura donde convertimos nuestros puntos clave 2D en coordenadas homogéneas normalizadas usando el modelo de cámara y luego compárelos con los puntos clave 3D de Ground Ground rotados y traducidos. Detenemos la optimización cuando la pérdida es inferior a nuestro umbral.
def pose_optimization (img, vértices, caras, keypoints_2d, conf, keypoints_3d, K): # Enviar variables a GPU [1965905050] device = keypoints_2d.device keypoints_3d = keypoints_3d.to (device) K = K.to (device) r = torch.rand (3, require_grad = True device = dispositivo) # rotación en representación de ángulo de eje t = torch.rand (3, require_grad = Verdadero device = device) d = conf.sqrt () [: Ninguna ] # Puntos clave 2D en coordenadas normalizadas norm_keypoints_2d = torch.matmul (K.inverse (), torch.cat ((keypoints_2d, torch.ones (keypoints_2d.shape [1945904444] , 1, dispositivo = dispositivo)), dim = -1) .t ()). T () [:,:-1] # configurar optimizador optimizador = torch.optim.Adam ([r,t]lr = 1e-2) # cheque de convergencia convergente = Falso rel_tol = 1e-7 [1 9459024] loss_old = 100 mientras que no convergió: optimizer.zero_grad () # convierte el ángulo del eje a la matriz de rotación R = rodrigues (r) # 1) Calcular los puntos clave proyectados según la estimación actual de R y t k3d = torch.matmul (R, puntos clave_3d.transpose (1, 0)) + t [: Ninguno ] proj_keypoints = (k3d / k3d [2]) [0:2,:] .transpose (1,0) # 2) Error de cálculo (basado en la distancia entre los puntos clave proyectados y los puntos clave detectados) err = torch.norm ((norm_keypoints_2d - proj_keypoints) * d) ** 2, & # 039; para & # 039;) # 3) Actualización basada en error err.backward () optimizer.step () # 4 ) Compruebe la convergencia si abs (err.detach () - loss_old) / loss_old <rel_tol: break else loss loss_old = err .detach () # print (err.detach (). cpu (). numpy ()) R = rodrigues (r) plt.figure () plot_mesh (img, vertices, faces, R.detach (). Cpu (). Numpy (), t.detach (). cpu (). numpy () [: Ninguna ]K.detach (). cpu (). numpy ()) plt.show () return rodrigues ( r) [0] .detach (), t.detach ()
Cuando ejecuta la función anterior, debería obtener resultados asombrosos como estos:
¡Eso es todo para los detalles importantes de la implementación! Asegúrese de ejecutarlo y jugar con los diferentes parámetros y transformaciones para comprender cómo afectan los resultados. Espero que esto haya sido de ayuda. ¡Por favor, envíenme cualquier comentario o corrección!
Referencias
[1] G.Pavlakos et al. 2017. Postura del objeto 6-DoF de puntos clave semánticos
Aprendizaje profundo geométrico para la estimación de posturas se publicó originalmente en Hacia la ciencia de datos en Medio, donde las personas continúan la conversación resaltando y respondiendo a esto historia