# Source code for gnes.service.frontend

import os
import threading
from concurrent.futures import ThreadPoolExecutor

import grpc

from ..client.base import ZmqClient
from ..helper import set_logger
from ..proto import gnes_pb2_grpc, add_envelope


class FrontendService:
    """gRPC frontend that relays incoming requests to a ``ZmqClient``.

    The service owns a ``grpc.server`` backed by a thread pool and registers
    a :class:`_Servicer` on it.  It is meant to be used as a context manager:
    entering starts the server, leaving stops it and releases :meth:`join`.
    """

    def __init__(self, args):
        """Build the gRPC server and bind it to ``grpc_host:grpc_port``.

        :param args: parsed arguments; this class reads ``proxy``,
            ``verbose``, ``max_concurrency``, ``max_message_size``,
            ``grpc_host`` and ``grpc_port`` and hands ``args`` on to the
            servicer.
        """
        if not args.proxy:
            for name in ('http_proxy', 'https_proxy'):
                # os.unsetenv() alone leaves os.environ stale, so code that
                # reads the mapping would still see the proxy.  Drop the key
                # from the mapping and keep the explicit unsetenv so the
                # process environment is cleared even if the mapping did not
                # track the variable.
                os.environ.pop(name, None)
                os.unsetenv(name)

        self.logger = set_logger(self.__class__.__name__, args.verbose)
        self.server = grpc.server(
            ThreadPoolExecutor(max_workers=args.max_concurrency),
            options=[('grpc.max_send_message_length', args.max_message_size),
                     ('grpc.max_receive_message_length', args.max_message_size)])
        self.logger.info('start a frontend with %d workers' % args.max_concurrency)
        gnes_pb2_grpc.add_GnesRPCServicer_to_server(self._Servicer(args), self.server)

        self.bind_address = '{0}:{1}'.format(args.grpc_host, args.grpc_port)
        self.server.add_insecure_port(self.bind_address)
        # Set by stop(); join() blocks on it.
        self._stop_event = threading.Event()

    def __enter__(self):
        """Start the gRPC server, re-arm the stop event and return ``self``."""
        self.server.start()
        self.logger.critical('listening at: %s' % self.bind_address)
        self._stop_event.clear()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the gRPC server (grace ``None``) and signal :meth:`stop`."""
        self.server.stop(None)
        self.stop()

    def stop(self):
        """Set the stop event so that any thread blocked in :meth:`join` resumes."""
        self._stop_event.set()

    def join(self):
        """Block the calling thread until :meth:`stop` has been called."""
        self._stop_event.wait()

    class _Servicer(gnes_pb2_grpc.GnesRPCServicer):
        """RPC handlers that forward requests over a per-call ``ZmqClient``."""

        def __init__(self, args):
            """Store ``args`` and prepare the keyword arguments for send/recv.

            :param args: parsed arguments; reads ``verbose``,
                ``check_version``, ``timeout`` and ``squeeze_pb`` here and
                ``max_pending_request`` in :meth:`StreamCall`.
            """
            self.args = args
            self.logger = set_logger(FrontendService.__name__, args.verbose)
            self.zmq_context = self.ZmqContext(args)
            self.request_id_cnt = 0
            self.send_recv_kwargs = dict(
                check_version=self.args.check_version,
                timeout=self.args.timeout,
                squeeze_pb=self.args.squeeze_pb)
            # Kept for backward compatibility only.  The in-flight count is
            # tracked per stream inside StreamCall, because one servicer
            # instance is shared by every handler thread.
            self.pending_request = 0

        def Call(self, request, context):
            """Send one request through a fresh client and return its reply.

            :param request: the incoming request message
            :param context: the gRPC call context (unused)
            :return: the message obtained from ``recv_message``
            """
            with self.zmq_context as zmq_client:
                zmq_client.send_message(add_envelope(request, zmq_client), **self.send_recv_kwargs)
                return zmq_client.recv_message(**self.send_recv_kwargs)

        def StreamCall(self, request_iterator, context):
            """Forward a stream of requests and yield responses as they arrive.

            Before each request is sent, already-available responses are
            yielded; once the number of in-flight requests exceeds
            ``args.max_pending_request`` by more than one, the handler blocks
            on receiving the surplus before sending more.  After the request
            iterator is exhausted, all remaining responses are awaited.

            :param request_iterator: iterator of incoming request messages
            :param context: the gRPC call context (unused)
            :return: generator of response messages
            """
            # Per-stream counter of requests sent but not yet answered.  It is
            # deliberately local: keeping it on ``self`` would let concurrent
            # streams reset and corrupt each other's count.
            pending = 0

            def get_response(num_recv, blocked=False):
                """Yield up to ``num_recv`` responses, then drain what is ready."""
                nonlocal pending
                if blocked:
                    self.logger.info('waiting for %d responses ...' % num_recv)
                for _ in range(num_recv):
                    # When not blocked, only receive if poll(1) reports data
                    # (timeout unit presumably milliseconds -- confirm).
                    if blocked or zmq_client.receiver.poll(1):
                        msg = zmq_client.recv_message(**self.send_recv_kwargs)
                        pending -= 1
                        yield msg

                # NOTE(review): assumed to run once after the loop above;
                # confirm the nesting against the upstream source.
                while zmq_client.receiver.poll(1):
                    msg = zmq_client.recv_message(**self.send_recv_kwargs)
                    pending -= 1
                    yield msg

            with self.zmq_context as zmq_client:
                for request in request_iterator:
                    self.logger.info('receive request: %s' % request.request_id)
                    num_recv = max(pending - self.args.max_pending_request, 1)
                    yield from get_response(num_recv, num_recv > 1)
                    self.logger.info('send new request into %d appending tasks' % (pending))
                    zmq_client.send_message(add_envelope(request, zmq_client), **self.send_recv_kwargs)
                    pending += 1

                self.logger.info('all requests are sent, waiting for the responses...')
                yield from get_response(pending, blocked=True)

        class ZmqContext:
            """The zmq context class."""

            def __init__(self, args):
                """Remember ``args`` and create the thread-local client slot."""
                self.args = args
                self.tlocal = threading.local()
                # threading.local attributes are per thread, so this default
                # is visible only in the thread that runs this constructor.
                self.tlocal.client = None

            def __enter__(self):
                """Enter the context."""
                client = ZmqClient(self.args)
                self.tlocal.client = client
                return client

            def __exit__(self, exc_type, exc_value, exc_traceback):
                """Exit the context."""
                self.tlocal.client.close()
                self.tlocal.client = None