Source code for gnes.encoder.numeric.pooling

import os
from typing import Tuple

import numpy as np

from ..base import BaseNumericEncoder
from ...helper import as_numpy_array


class PoolingEncoder(BaseNumericEncoder):
    """Pool a sequence of embeddings into a single vector via masked mean, max, or both."""

    def __init__(self, pooling_strategy: str = 'REDUCE_MEAN',
                 backend: str = 'numpy', *args, **kwargs):
        super().__init__(*args, **kwargs)
        valid_poolings = {'REDUCE_MEAN', 'REDUCE_MAX', 'REDUCE_MEAN_MAX'}
        valid_backends = {'tensorflow', 'numpy', 'pytorch', 'torch'}
        if pooling_strategy not in valid_poolings:
            raise ValueError('"pooling_strategy" must be one of %s' % valid_poolings)
        if backend not in valid_backends:
            raise ValueError('"backend" must be one of %s' % valid_backends)
        self.pooling_strategy = pooling_strategy
        self.backend = backend

    def post_init(self):
        # lazily import the chosen backend so unused frameworks are never required
        if self.backend in {'pytorch', 'torch'}:
            import torch
            self.torch = torch
        elif self.backend == 'tensorflow':
            # expose GPU 0 to TensorFlow when on_gpu is set, otherwise force CPU-only
            os.environ['CUDA_VISIBLE_DEVICES'] = '0' if self.on_gpu else '-1'
            import tensorflow as tf
            self._tf_graph = tf.Graph()
            config = tf.ConfigProto(device_count={'GPU': 1 if self.on_gpu else 0})
            config.gpu_options.allow_growth = True
            config.log_device_placement = False
            self._sess = tf.Session(graph=self._tf_graph, config=config)
            self.tf = tf

    def mul_mask(self, x, m):
        # zero out embeddings at padded positions: x is (batch, seq_len, dim), m is (batch, seq_len)
        if self.backend in {'pytorch', 'torch'}:
            return self.torch.mul(x, m.unsqueeze(2))
        elif self.backend == 'tensorflow':
            with self._tf_graph.as_default():
                return x * self.tf.expand_dims(m, axis=-1)
        elif self.backend == 'numpy':
            return x * np.expand_dims(m, axis=-1)

    def minus_mask(self, x, m, offset: float = 1e30):
        # push padded positions towards -inf so they can never win a max-pooling
        if self.backend in {'pytorch', 'torch'}:
            return x - (1.0 - m).unsqueeze(2) * offset
        elif self.backend == 'tensorflow':
            with self._tf_graph.as_default():
                return x - self.tf.expand_dims(1.0 - m, axis=-1) * offset
        elif self.backend == 'numpy':
            return x - np.expand_dims(1.0 - m, axis=-1) * offset

    def masked_reduce_mean(self, x, m, jitter: float = 1e-10):
        # average over valid positions only; jitter guards against division by zero
        if self.backend in {'pytorch', 'torch'}:
            return self.torch.div(self.torch.sum(self.mul_mask(x, m), dim=1),
                                  self.torch.sum(m.unsqueeze(2), dim=1) + jitter)
        elif self.backend == 'tensorflow':
            with self._tf_graph.as_default():
                return self.tf.reduce_sum(self.mul_mask(x, m), axis=1) / (
                        self.tf.reduce_sum(m, axis=1, keepdims=True) + jitter)
        elif self.backend == 'numpy':
            return np.sum(self.mul_mask(x, m), axis=1) / (np.sum(m, axis=1, keepdims=True) + jitter)

    def masked_reduce_max(self, x, m):
        # max over valid positions; padded positions were offset to very negative values
        if self.backend in {'pytorch', 'torch'}:
            return self.torch.max(self.minus_mask(x, m), 1)[0]
        elif self.backend == 'tensorflow':
            with self._tf_graph.as_default():
                return self.tf.reduce_max(self.minus_mask(x, m), axis=1)
        elif self.backend == 'numpy':
            return np.max(self.minus_mask(x, m), axis=1)

    @as_numpy_array
    def encode(self, data: Tuple, *args, **kwargs):
        # data is a (seq_tensor, mask_tensor) pair: (batch, seq_len, dim) embeddings
        # plus a (batch, seq_len) mask with 1 for valid tokens and 0 for padding
        seq_tensor, mask_tensor = data
        if self.pooling_strategy == 'REDUCE_MEAN':
            r = self.masked_reduce_mean(seq_tensor, mask_tensor)
        elif self.pooling_strategy == 'REDUCE_MAX':
            r = self.masked_reduce_max(seq_tensor, mask_tensor)
        elif self.pooling_strategy == 'REDUCE_MEAN_MAX':
            # concatenate mean- and max-pooled vectors along the feature axis
            if self.backend in {'pytorch', 'torch'}:
                r = self.torch.cat((self.masked_reduce_mean(seq_tensor, mask_tensor),
                                    self.masked_reduce_max(seq_tensor, mask_tensor)), dim=1)
            elif self.backend == 'tensorflow':
                with self._tf_graph.as_default():
                    r = self.tf.concat([self.masked_reduce_mean(seq_tensor, mask_tensor),
                                        self.masked_reduce_max(seq_tensor, mask_tensor)], axis=1)
            elif self.backend == 'numpy':
                r = np.concatenate([self.masked_reduce_mean(seq_tensor, mask_tensor),
                                    self.masked_reduce_max(seq_tensor, mask_tensor)], axis=1)
        if self.backend == 'tensorflow':
            # TF graph ops are symbolic; evaluate the result before returning
            r = self._sess.run(r)
        return r
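
# --- Usage sketch (not part of the module source) ---------------------------
# A minimal, hedged example of driving the numpy backend: pooling a batch of
# token embeddings under a padding mask. The shapes and the direct encode()
# call here are illustrative assumptions; in a GNES pipeline the encoder is
# normally constructed from YAML and post_init() is invoked by the framework.
#
#     import numpy as np
#     from gnes.encoder.numeric.pooling import PoolingEncoder
#
#     enc = PoolingEncoder(pooling_strategy='REDUCE_MEAN', backend='numpy')
#     enc.post_init()  # a no-op for the numpy backend
#
#     seq = np.random.rand(2, 4, 8).astype(np.float32)   # (batch, seq_len, dim)
#     mask = np.array([[1, 1, 1, 0],                     # last position is padding
#                      [1, 1, 0, 0]], dtype=np.float32)  # (batch, seq_len)
#
#     pooled = enc.encode((seq, mask))                   # -> shape (2, 8)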