import io
import os
from typing import List
import numpy as np
from PIL import Image
from .resize import SizedPreprocessor
from ..helper import torch_transform, get_all_subarea
from ...proto import array2blob
[docs]class SegmentPreprocessor(SizedPreprocessor):
def __init__(self, model_name: str,
model_dir: str,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.model_name = model_name
self.model_dir = model_dir
[docs] def post_init(self):
import torch
import torchvision.models as models
os.environ['TORCH_HOME'] = self.model_dir
self._model = getattr(models.detection, self.model_name)(pretrained=True)
self._model = self._model.eval()
if self.on_gpu:
# self._model.cuda()
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self._model = self._model.to(self._device)
[docs] def apply(self, doc: 'gnes_pb2.Document'):
super().apply(doc)
if doc.raw_bytes:
original_image = Image.open(io.BytesIO(doc.raw_bytes))
all_subareas, index = get_all_subarea(original_image)
image_tensor = torch_transform(original_image)
if self.on_gpu:
image_tensor = image_tensor.cuda()
seg_output = self._model([image_tensor])
weight = seg_output[0]['scores'].tolist()
length = len(list(filter(lambda x: x >= 0.5, weight)))
chunks = seg_output[0]['boxes'].tolist()[:length]
weight = weight[:length]
for ci, ele in enumerate(zip(chunks, weight)):
c = doc.chunks.add()
c.doc_id = doc.doc_id
c.blob.CopyFrom(array2blob(self._crop(original_image, ele[0])))
c.offset = ci
c.offset_nd.extend(self._get_seg_offset_nd(all_subareas, index, ele[0]))
c.weight = self._cal_area(ele[0]) / (original_image.size[0] * original_image.size[1])
c = doc.chunks.add()
c.doc_id = doc.doc_id
c.blob.CopyFrom(array2blob(np.array(original_image)))
c.offset = len(chunks)
c.offset_nd.extend([100, 100])
c.weight = 1.
else:
self.logger.error('bad document: "raw_bytes" is empty!')
def _get_seg_offset_nd(self, all_subareas: List[List[int]], index: List[List[int]], chunk: List[int]) -> List[int]:
iou_list = [self._cal_iou(area, chunk) for area in all_subareas]
return index[int(np.argmax(iou_list))][:2]
@staticmethod
def _crop(original_image, coordinates):
return np.array(original_image.crop(coordinates))
@staticmethod
def _cal_area(coordinate: List[int]):
return (coordinate[2] - coordinate[0]) * (coordinate[3] - coordinate[1])
def _cal_iou(self, image: List[int], chunk: List[int]) -> float:
chunk_area = self._cal_area(chunk)
image_area = self._cal_area(image)
x1 = max(chunk[0], image[0])
y1 = max(chunk[1], image[1])
x2 = min(chunk[2], image[2])
y2 = min(chunk[3], image[3])
overlap_area = max(0, x2 - x1) * max(0, y2 - y1)
iou = overlap_area / (chunk_area + image_area - overlap_area)
return iou