Source code for gnes.encoder.base

from typing import List, Any, Tuple, Union

import numpy as np

from ..base import TrainableBase, CompositionalTrainableBase


class BaseEncoder(TrainableBase):
    """Common parent of the encoder classes defined in this module.

    Both methods below are empty placeholders: they perform no work and
    return ``None``.
    """

    def encode(self, data: Any, *args, **kwargs) -> Any:
        """Placeholder for the encoding step.

        The base implementation has an empty body, so it returns ``None``
        regardless of ``data``. Presumably concrete encoders override it --
        confirm against the subclasses.

        :param data: the input to encode; no type is enforced here
        :return: ``None`` in this base implementation
        """

    def _copy_from(self, x: 'BaseEncoder') -> None:
        """Placeholder hook that receives another encoder ``x``.

        Does nothing in this base implementation.

        :param x: the other encoder instance
        """
class BaseImageEncoder(BaseEncoder):
    """Encoder base whose ``encode`` is annotated to take a list of arrays."""

    def encode(self, img: List['np.ndarray'], *args, **kwargs) -> np.ndarray:
        """Placeholder for image encoding.

        The body is empty, so this base implementation returns ``None``
        despite the ``np.ndarray`` return annotation.

        :param img: annotated as a list of ``np.ndarray``
        :return: ``None`` in this base implementation
        """
class BaseVideoEncoder(BaseEncoder):
    """Encoder base whose ``encode`` is annotated to take a list of arrays."""

    def encode(self, data: List['np.ndarray'], *args, **kwargs) -> Union[np.ndarray, List['np.ndarray']]:
        """Placeholder for video encoding.

        The body is empty, so this base implementation returns ``None``.
        The annotation allows either a single array or a list of arrays as
        the result.

        :param data: annotated as a list of ``np.ndarray``
        :return: ``None`` in this base implementation
        """
class BaseTextEncoder(BaseEncoder):
    """Encoder base whose ``encode`` is annotated to take a list of strings."""

    def encode(self, text: List[str], *args, **kwargs) -> Union[Tuple, np.ndarray]:
        """Placeholder for text encoding.

        The body is empty, so this base implementation returns ``None``.
        The annotation allows either a tuple or an ``np.ndarray`` as the
        result.

        :param text: annotated as a list of ``str``
        :return: ``None`` in this base implementation
        """
class BaseNumericEncoder(BaseEncoder):
    """Encoder base whose ``encode`` is annotated array-in, array-out.

    Note: a numeric encoder cannot be used as the first encoder of a
    pipeline.
    """

    def encode(self, data: np.ndarray, *args, **kwargs) -> np.ndarray:
        """Placeholder for numeric encoding.

        The body is empty, so this base implementation returns ``None``
        despite the ``np.ndarray`` return annotation.

        :param data: annotated as ``np.ndarray``
        :return: ``None`` in this base implementation
        """
class BaseAudioEncoder(BaseEncoder):
    """Encoder base whose ``encode`` is annotated to take a list of arrays."""

    def encode(self, data: List['np.ndarray'], *args, **kwargs) -> np.ndarray:
        """Placeholder for audio encoding.

        The body is empty, so this base implementation returns ``None``
        despite the ``np.ndarray`` return annotation.

        :param data: annotated as a list of ``np.ndarray``
        :return: ``None`` in this base implementation
        """
class BaseBinaryEncoder(BaseEncoder):
    """Encoder that serialises a ``uint8`` array into raw bytes."""

    def encode(self, data: np.ndarray, *args, **kwargs) -> bytes:
        """Return the raw byte content of ``data``.

        :param data: array whose dtype must be ``np.uint8``
        :return: the result of ``data.tobytes()``
        :raises ValueError: if ``data.dtype`` is not ``np.uint8``
        """
        received = data.dtype
        if received == np.uint8:
            return data.tobytes()
        # Any other dtype is rejected rather than silently converted.
        raise ValueError('data must be np.uint8 but received %s' % received)
class PipelineEncoder(CompositionalTrainableBase):
    """Encoder that chains the encoders held in ``self.components``."""

    def encode(self, data: Any, *args, **kwargs) -> Any:
        """Pass ``data`` through every component's ``encode`` in order.

        The output of each component becomes the input of the next one; the
        same ``*args`` and ``**kwargs`` are forwarded to every call.

        :param data: input handed to the first component
        :return: the value returned by the last component's ``encode``
        :raises NotImplementedError: if ``self.components`` is empty
        """
        stages = self.components
        if not stages:
            raise NotImplementedError

        result = data
        for stage in stages:
            result = stage.encode(result, *args, **kwargs)
        return result

    def train(self, data, *args, **kwargs):
        """Train the components one after another.

        A component is trained only when its ``is_trained`` flag is falsy.
        After every component except the last, ``data`` is replaced by that
        component's ``encode`` output, so the next component is trained on
        it.

        :param data: training input for the first component
        :raises NotImplementedError: if ``self.components`` is empty
        """
        stages = self.components
        if not stages:
            raise NotImplementedError

        final_position = len(stages) - 1
        for position, stage in enumerate(stages):
            if not stage.is_trained:
                stage.train(data, *args, **kwargs)

            # Nothing consumes the last component's output during training,
            # so its encode step is skipped.
            if position < final_position:
                data = stage.encode(data, *args, **kwargs)