import os
import threading
from concurrent.futures import ThreadPoolExecutor
import grpc
from ..client.base import ZmqClient
from ..helper import set_logger
from ..proto import gnes_pb2_grpc, add_envelope
class FrontendService:
    """gRPC frontend that relays ``GnesRPC`` calls through a ``ZmqClient``.

    Use it as a context manager: ``__enter__`` starts the gRPC server and
    ``__exit__`` stops it. ``join`` blocks the caller until ``stop`` is called
    (``__exit__`` calls it too).

    :param args: parsed options (presumably an ``argparse.Namespace`` -- confirm
        against the CLI). This class reads ``proxy``, ``verbose``,
        ``max_concurrency``, ``max_message_size``, ``grpc_host`` and
        ``grpc_port``. The servicer additionally reads ``check_version``,
        ``timeout``, ``squeeze_pb`` and ``max_pending_request``, and the whole
        object is handed to ``ZmqClient``.
    """

    def __init__(self, args):
        if not args.proxy:
            # Delete through os.environ instead of calling os.unsetenv directly:
            # os.environ deletion also unsets the process environment, while a
            # bare unsetenv() leaves os.environ holding stale values (and was not
            # available on Windows before Python 3.9).
            os.environ.pop('http_proxy', None)
            os.environ.pop('https_proxy', None)
        self.logger = set_logger(self.__class__.__name__, args.verbose)
        self.server = grpc.server(
            ThreadPoolExecutor(max_workers=args.max_concurrency),
            options=[('grpc.max_send_message_length', args.max_message_size),
                     ('grpc.max_receive_message_length', args.max_message_size)])
        self.logger.info('start a frontend with %d workers', args.max_concurrency)
        gnes_pb2_grpc.add_GnesRPCServicer_to_server(self._Servicer(args), self.server)
        self.bind_address = '{0}:{1}'.format(args.grpc_host, args.grpc_port)
        self.server.add_insecure_port(self.bind_address)
        self._stop_event = threading.Event()

    def __enter__(self):
        """Start the gRPC server, reset the stop event and return ``self``."""
        self.server.start()
        self.logger.critical('listening at: %s', self.bind_address)
        self._stop_event.clear()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the server immediately (no grace period) and release ``join``.

        Returns ``None``, so any exception raised inside the ``with`` block
        propagates.
        """
        self.server.stop(None)
        self.stop()

    def stop(self):
        """Set the stop event, waking every thread blocked in ``join``."""
        self._stop_event.set()

    def join(self):
        """Block the calling thread until ``stop`` is called."""
        self._stop_event.wait()

    class _Servicer(gnes_pb2_grpc.GnesRPCServicer):
        """``GnesRPC`` servicer that forwards each request through a ZeroMQ client.

        One instance is registered with the server and shared by all worker
        threads of its ``ThreadPoolExecutor``, so per-request state must stay
        local to each call.
        """

        def __init__(self, args):
            self.args = args
            self.logger = set_logger(FrontendService.__name__, args.verbose)
            self.zmq_context = self.ZmqContext(args)
            # Not read anywhere in this class; kept for compatibility.
            self.request_id_cnt = 0
            # Keyword arguments shared by every send_message / recv_message call.
            self.send_recv_kwargs = dict(
                check_version=self.args.check_version,
                timeout=self.args.timeout,
                squeeze_pb=self.args.squeeze_pb)

        def Call(self, request, context):
            """Relay a single request and return the one message received back.

            :param request: the incoming request, wrapped with ``add_envelope``
                before sending.
            :param context: gRPC servicer context (unused).
            :return: the message returned by ``ZmqClient.recv_message``.
            """
            with self.zmq_context as zmq_client:
                zmq_client.send_message(add_envelope(request, zmq_client), **self.send_recv_kwargs)
                return zmq_client.recv_message(**self.send_recv_kwargs)

        def StreamCall(self, request_iterator, context):
            """Relay a stream of requests, yielding responses as they arrive.

            Before each request is sent, responses that are already available are
            yielded. If too many requests are outstanding, the call blocks until
            enough responses come back. Once the request stream is exhausted, all
            remaining responses are awaited and yielded.

            :param request_iterator: iterable of incoming requests.
            :param context: gRPC servicer context (unused).
            """
            # Outstanding-request counter, kept per call. It previously lived on
            # the shared servicer instance, so concurrent streams reset and
            # corrupted each other's count.
            num_pending = 0

            def get_response(zmq_client, num_recv, blocked=False):
                # Receive num_recv messages (waiting for each when blocked,
                # otherwise only those that arrive within a 1 ms poll), then
                # drain whatever else is immediately available.
                nonlocal num_pending
                if blocked:
                    self.logger.info('waiting for %d responses ...', num_recv)
                for _ in range(num_recv):
                    if blocked or zmq_client.receiver.poll(1):
                        msg = zmq_client.recv_message(**self.send_recv_kwargs)
                        num_pending -= 1
                        yield msg
                while zmq_client.receiver.poll(1):
                    msg = zmq_client.recv_message(**self.send_recv_kwargs)
                    num_pending -= 1
                    yield msg

            with self.zmq_context as zmq_client:
                for request in request_iterator:
                    self.logger.info('receive request: %s', request.request_id)
                    # NOTE(review): blocking only starts once num_pending is at
                    # least max_pending_request + 2, so up to
                    # max_pending_request + 2 requests can be in flight -- confirm
                    # this is intended.
                    num_recv = max(num_pending - self.args.max_pending_request, 1)
                    yield from get_response(zmq_client, num_recv, num_recv > 1)
                    self.logger.info('send new request into %d appending tasks', num_pending)
                    zmq_client.send_message(add_envelope(request, zmq_client), **self.send_recv_kwargs)
                    num_pending += 1
                self.logger.info('all requests are sent, waiting for the responses...')
                yield from get_response(zmq_client, num_pending, blocked=True)

        class ZmqContext:
            """The zmq context class.

            Each ``with`` block creates a fresh ``ZmqClient`` and closes it on
            exit. The client is recorded in thread-local storage so that
            concurrent threads sharing this object do not close each other's
            client.
            """

            def __init__(self, args):
                self.args = args
                self.tlocal = threading.local()
                # Only initialises the attribute for the constructing thread.
                self.tlocal.client = None

            def __enter__(self):
                """Create a ``ZmqClient``, remember it for this thread and return it."""
                client = ZmqClient(self.args)
                self.tlocal.client = client
                return client

            def __exit__(self, exc_type, exc_value, exc_traceback):
                """Close the client this thread opened, if any; exceptions propagate."""
                # threading.local attributes exist only in threads that set them.
                # If __exit__ ever runs on a thread that did not call __enter__,
                # read defensively instead of raising AttributeError, which would
                # mask the exception that triggered the exit.
                client = getattr(self.tlocal, 'client', None)
                self.tlocal.client = None
                if client is not None:
                    client.close()