# Source code for gnes.service.encoder

from typing import List, Union

from .base import BaseService as BS, MessageHandler
from ..proto import gnes_pb2, array2blob, blob2array


class EncoderService(BS):
    """Service that runs an encoder model over the chunks of incoming documents.

    The model is loaded in :meth:`post_init`. Handlers for index, train and
    query requests are registered on the class-level ``handler``.
    """

    handler = MessageHandler(BS.handler)

    def post_init(self):
        """Load the encoder model and start with an empty training buffer."""
        # Imported inside the method, as in the original code.
        from ..encoder.base import BaseEncoder

        self._model = self.load_model(BaseEncoder)
        # Contents accumulated by train requests until a flush is requested.
        self.train_data = []

    def embed_chunks_in_docs(self, docs: Union[List['gnes_pb2.Document'], 'gnes_pb2.Document'],
                             do_encoding: bool = True, is_input_list: bool = True):
        """Collect the chunk contents of ``docs`` and optionally embed them in place.

        :param docs: a list of documents, or a single document when
            ``is_input_list`` is False
        :param do_encoding: when True and at least one content was collected,
            encode the contents with ``self._model`` and copy each resulting
            embedding into the corresponding chunk
        :param is_input_list: set to False when ``docs`` is a single document
        :return: the collected contents; ``c.text`` for chunks of TEXT
            documents and ``blob2array(c.blob)`` for chunks of IMAGE/VIDEO
            documents. Chunks of any other document type are skipped.

        Failures during encoding are logged and not re-raised.
        """
        if not is_input_list:
            docs = [docs]

        contents = []
        chunks = []

        for d in docs:
            if not d.chunks:
                self.logger.warning('document (doc_id=%s) contains no chunks!' % d.doc_id)
                continue

            for c in d.chunks:
                if d.doc_type == gnes_pb2.Document.TEXT:
                    contents.append(c.text)
                elif d.doc_type in {gnes_pb2.Document.IMAGE, gnes_pb2.Document.VIDEO}:
                    contents.append(blob2array(c.blob))
                else:
                    self.logger.warning(
                        'chunk content is in type: %s, dont kow how to handle that, ignored' % c.WhichOneof('content'))
                    # Really ignore the chunk: no content was collected for it,
                    # so recording it would shift every later embedding onto
                    # the wrong chunk.
                    continue
                # Invariant: chunks[i] is the chunk that produced contents[i].
                chunks.append(c)

        if do_encoding and contents:
            try:
                embeds = self._model.encode(contents)
                if len(chunks) != embeds.shape[0]:
                    self.logger.error(
                        'mismatched %d chunks and a %s shape embedding, '
                        'the first dimension must be the same' % (len(chunks), embeds.shape))
                for idx, c in enumerate(chunks):
                    c.embedding.CopyFrom(array2blob(embeds[idx]))
            except Exception as ex:
                # Best effort: log the failure and let the message continue
                # through the pipeline without (all of) its embeddings.
                self.logger.error(ex, exc_info=True)
                self.logger.warning('encoder service throws an exception, '
                                    'the sequel pipeline may not work properly')

        return contents

    @handler.register(gnes_pb2.Request.IndexRequest)
    def _handler_index(self, msg: 'gnes_pb2.Message'):
        """Embed the chunks of every document carried by an index request."""
        self.embed_chunks_in_docs(msg.request.index.docs)

    @handler.register(gnes_pb2.Request.TrainRequest)
    def _handler_train(self, msg: 'gnes_pb2.Message'):
        """Buffer the contents of a train request and train the model on flush.

        Documents in the request are not encoded; their contents are appended
        to ``self.train_data``. When the request asks for a flush, the model is
        trained on everything buffered so far and the buffer is emptied.
        """
        if msg.request.train.docs:
            contents = self.embed_chunks_in_docs(msg.request.train.docs, do_encoding=False)
            self.train_data.extend(contents)
            msg.response.train.status = gnes_pb2.Response.PENDING
            # raise BlockMessage

        if msg.request.train.flush:
            self._model.train(self.train_data)
            self.logger.info('%d samples is flushed for training' % len(self.train_data))
            self.is_model_changed.set()
            self.train_data.clear()
            # NOTE(review): status is reported on response.control here but on
            # response.train in the buffering branch above - confirm intended.
            msg.response.control.status = gnes_pb2.Response.SUCCESS

    @handler.register(gnes_pb2.Request.QueryRequest)
    def _handler_search(self, msg: 'gnes_pb2.Message'):
        """Embed the chunks of the single query document of a search request."""
        self.embed_chunks_in_docs(msg.request.search.query, is_input_list=False)

    @handler.register_hook(hook_type=('pre', 'post'), only_when_verbose=True)
    def _hook_debug_msg(self, msg: 'gnes_pb2.Message', *args, **kwargs):
        """Log a pretty-printed summary of ``msg`` for debugging.

        Each field is wrapped in a lambda so that a lookup that fails (for
        example indexing into an empty list of docs) is reported in the
        summary instead of raising.
        """
        from pprint import pformat

        debug_kv = {
            'envelope': lambda: msg.envelope,
            'num_docs': lambda: len(msg.request.index.docs),
            'num_chunks in doc[0]': lambda: len(msg.request.index.docs[0].chunks),
            'docs[0].chunks[0].content_type': lambda: msg.request.index.docs[0].chunks[0].WhichOneof('content'),
            'docs[0].chunks[0].weight': lambda: msg.request.index.docs[0].chunks[0].weight,
            'docs[0].chunks[0].embedding': lambda: blob2array(msg.request.index.docs[0].chunks[0].embedding),
            'docs[0].chunks[0].embedding[0]': lambda: blob2array(msg.request.index.docs[0].chunks[0].embedding)[0]
        }
        debug_info = {}
        for k, v in debug_kv.items():
            try:
                r = v()
            except Exception as ex:
                r = 'fail to get the value, reason: %s' % ex
            debug_info[k] = r
        self.logger.info(pformat(debug_info))