SiamMask es una arquitectura de modelo de deep learning que realiza el seguimiento visual de objetos (VOT) y la segmentación semisupervisada de objeto en vídeos (VOS). Dada la ubicación del objeto en el primer frame de la secuencia, el objetivo de VOT es estimar la posición de un objeto en los sisguientes frames con la mayor precisión posible. De la misma forma, el objetivo principal de VOS es obtener una máscara de segmentación binaria que exprese si un píxel pertenece o no al objetivo.
SiamMask es una red siamesa que añade una nueva rama para producir una máscara binaria a nivel de píxel.
Hay dos variantes, con dos y con tres ramas.
# Cargar clase del módulo
from arcgis.learn import SiamMask
# Iniciar SiamMask. Hay dos formas:
## Con pesos pre-entrenados DAVIS (sin pasar parámetros)
ot = SiamMask()
## Entrenar modelo con datos propios con el método prepare_data
from arcgis.learn import prepare_data
data = prepare_data(r"path_to_data_folder", dataset_type="ObjectTracking", batch_size=64)
ot = SiamMask(data)
# Inicializar el objeto
## Inicializar objetos utilizando bounding box. Los parámetros que usamos son:
## * frame: array de numpy donde se inicialia el seguimiento del objeto
## * detections: lista de cajas delimitadoras para inicializar el objeto como [[x,y,w,h]] = x-coordinate, y-coordinate, width of bbox, height of bbox
## * labels: lista de etiquetas que representa la clase del objeto
## * reset: booleano que podemos poner a True si queremos resetear de los valores del seguimiento anterior
tracks = ot.init(img, [[x,y,w,h]], [['truck']])
# Actualizar la localización del objeto en los siguientes frames
## Se pasa como parámetro el frame, el array de numpy para actualizar la posición del objeto
state = ot.update(frame)
Si ejecutas este código hay que tener en cuenta que:
import numpy as np
import cv2
from arcgis.learn import SiamMask
ot = SiamMask() # O cargar modelo concreto con SiamMask.from_model("path_to_save_model")
cap = cv2.VideoCapture(r"path_to_video_file")
# Aquí se empieza a analizar el vídeo
initialized = False
while(True):
ret, frame = cap.read()
if ret is False:
break
if initialized:
state = ot.update(frame) ## Update the track location in the frame
for track in state:
mask = track.mask
frame[:, :, 2] = (mask > 0) * 255 + (mask == 0) * frame[:, :, 2]
cv2.polylines(frame, [np.int0(track.location).reshape((-1, 1, 2))], True, (w, 255, h), 1)
cv2.imshow('frame',frame)
key = cv2.waitKey(1)
if key & 0xFF == ord('q'):
break
if key == 32:
init_rect = cv2.selectROI('frame', frame, False, False)
values = np.array(init_rect)
if all(values == 0):
continue
x, y, w, h = init_rect
state = ot.init(frame, [[x,y,w,h]]) ## Initialize the track in the frame
initialized = True
cv2.waitKey()
cap.release()
cv2.destroyAllWindows()
Ejemplo de ejecución en ArcGIS Pro: