Source code for gnes.preprocessor.base

import io

import numpy as np

from ..base import TrainableBase, CompositionalTrainableBase
from ..proto import gnes_pb2, array2blob

class BasePreprocessor(TrainableBase):
    """Common base for preprocessors that tag a document with a type.

    Subclasses override :attr:`doc_type`; :meth:`apply` writes that value
    onto the document and, optionally, fills in a default weight.
    """

    # Document type written onto every document passed to ``apply``.
    doc_type = gnes_pb2.Document.UNKNOWN

    def __init__(self, uniform_doc_weight: bool = True, *args, **kwargs):
        """
        :param uniform_doc_weight: when true, :meth:`apply` assigns weight
            ``1.0`` to any document whose weight is falsy (unset or zero)
        :param args: positional arguments forwarded to the parent constructor
        :param kwargs: keyword arguments forwarded to the parent constructor
        """
        super().__init__(*args, **kwargs)
        self.uniform_doc_weight = uniform_doc_weight

    def apply(self, doc: 'gnes_pb2.Document') -> None:
        """Stamp ``doc`` in place with this preprocessor's document type.

        :param doc: the document to modify; nothing is returned
        """
        doc.doc_type = self.doc_type
        # Leave the weight alone when it is already set or when uniform
        # weighting is switched off.
        if doc.weight or not self.uniform_doc_weight:
            return
        doc.weight = 1.0
class BaseTextPreprocessor(BasePreprocessor):
    """Preprocessor base whose documents are tagged as ``TEXT``."""

    doc_type = gnes_pb2.Document.TEXT
class BaseAudioPreprocessor(BasePreprocessor):
    """Preprocessor base whose documents are tagged as ``AUDIO``."""

    doc_type = gnes_pb2.Document.AUDIO
class BaseImagePreprocessor(BasePreprocessor):
    """Preprocessor base whose documents are tagged as ``IMAGE``."""

    doc_type = gnes_pb2.Document.IMAGE
class BaseVideoPreprocessor(BasePreprocessor):
    """Preprocessor base whose documents are tagged as ``VIDEO``."""

    doc_type = gnes_pb2.Document.VIDEO
class PipelinePreprocessor(CompositionalTrainableBase):
    """Chain several component preprocessors and run them in order."""

    def apply(self, doc: 'gnes_pb2.Document') -> None:
        """Apply every component to ``doc`` in sequence, in place.

        :param doc: the document handed to each component's ``apply``
        :raises NotImplementedError: if the pipeline has no components
        """
        if not self.components:
            raise NotImplementedError
        for be in self.components:
            be.apply(doc)

    def train(self, data, *args, **kwargs):
        """Train every component in order, feeding each the processed data.

        Each component except the last is applied to ``data`` after being
        trained, so that the next component is trained on the output of the
        previous stages.

        :param data: training data passed to each component's ``train``
        :param args: extra positional arguments forwarded to ``train``/``apply``
        :param kwargs: extra keyword arguments forwarded to ``train``/``apply``
        :raises NotImplementedError: if the pipeline has no components
        """
        if not self.components:
            raise NotImplementedError
        last_idx = len(self.components) - 1
        for idx, be in enumerate(self.components):
            be.train(data, *args, **kwargs)
            if idx < last_idx:
                processed = be.apply(data, *args, **kwargs)
                # Components whose ``apply`` works in place return None;
                # rebinding ``data`` to that would train every later
                # component on None. Only adopt an actual return value.
                if processed is not None:
                    data = processed
class UnaryPreprocessor(BasePreprocessor):
    """Turn the whole raw payload of a document into one single chunk."""

    # Flagged as trained at class level; this class defines no training step.
    is_trained = True

    def __init__(self, doc_type: int, *args, **kwargs):
        """
        :param doc_type: document type value (compared against the
            ``gnes_pb2.Document`` type constants) that decides how raw bytes
            are converted in :meth:`raw_to_chunk`
        :param args: positional arguments forwarded to the parent constructor
        :param kwargs: keyword arguments forwarded to the parent constructor
        """
        super().__init__(*args, **kwargs)
        self.doc_type = doc_type

    def apply(self, doc: 'gnes_pb2.Document'):
        """Add one chunk to ``doc`` and fill it from ``doc.raw_bytes``.

        The chunk is appended before the payload is inspected, so a document
        with empty ``raw_bytes`` still receives an (unfilled) chunk and an
        error is logged.

        :param doc: the document to modify in place
        """
        super().apply(doc)
        c = doc.chunks.add()
        c.doc_id = doc.doc_id
        c.offset = 0
        c.weight = 1.
        if doc.raw_bytes:
            self.raw_to_chunk(c, doc.raw_bytes)
        else:
            self.logger.error('bad document: "raw_bytes" is empty!')

    def raw_to_chunk(self, chunk: 'gnes_pb2.Chunk', raw_bytes: bytes):
        """Store ``raw_bytes`` on ``chunk`` according to ``self.doc_type``.

        :param chunk: the chunk to fill in place
        :param raw_bytes: raw payload of the document
        :raises NotImplementedError: for any type other than TEXT or IMAGE
        """
        if self.doc_type == gnes_pb2.Document.TEXT:
            chunk.text = raw_bytes.decode()
        elif self.doc_type == gnes_pb2.Document.IMAGE:
            # Imported lazily so PIL is only needed when images are handled.
            from PIL import Image
            # Decode the encoded image held in memory into an ndarray.
            img = np.array(Image.open(io.BytesIO(raw_bytes)))
            chunk.blob.CopyFrom(array2blob(img))
        else:
            # VIDEO and every other document type are not supported here.
            raise NotImplementedError
class RawChunkPreprocessor(BasePreprocessor):
    """Fill the ``raw`` field of chunks that already exist on a document.

    Subclasses supply the conversion by overriding :meth:`_parse_chunk`.
    """

    @staticmethod
    def _parse_chunk(chunk: 'gnes_pb2.Chunk', *args, **kwargs):
        """Produce the value stored in ``chunk.raw``; must be overridden.

        :param chunk: the chunk to convert
        :raises NotImplementedError: always, in this base implementation
        """
        raise NotImplementedError

    def apply(self, doc: 'gnes_pb2.Document') -> None:
        """Set ``raw`` on every chunk of ``doc`` via :meth:`_parse_chunk`.

        A document with empty ``raw_bytes`` is left untouched and an error
        is logged. The parent ``apply`` is not invoked here.

        :param doc: the document whose chunks are updated in place
        """
        if not doc.raw_bytes:
            self.logger.error('bad document: "raw_bytes" is empty!')
            return
        for piece in doc.chunks:
            piece.raw = self._parse_chunk(piece)