import ctypes
import os
import random
from typing import List, Iterator, Tuple
from typing import Optional
import numpy as np
import zmq
from termcolor import colored
from . import gnes_pb2
from ..helper import batch_iterator, default_logger
# Names exported by a star-import of this module.
__all__ = ['RequestGenerator', 'send_message', 'recv_message',
'blob2array', 'array2blob', 'gnes_pb2', 'add_route', 'add_version']
class RequestGenerator:
    """
    Namespace of generator helpers that build ``gnes_pb2.Request`` messages.

    Every method is a generator: nothing is built (or validated) until the
    returned iterator is consumed.
    """

    @staticmethod
    def generate(data: Iterator[bytes], batch_size: int = 0, doc_type: int = gnes_pb2.Document.TEXT,
                 doc_id_start: int = 0, request_id_start: int = 0,
                 random_doc_id: bool = False, fill_docs_for: str = 'index',
                 *args, **kwargs):
        """
        Yield one request per batch returned by ``batch_iterator(data, batch_size)``.

        :param data: iterator of raw document payloads
        :param batch_size: forwarded as-is to ``batch_iterator``
        :param doc_type: value stored in ``doc_type`` of every document
        :param doc_id_start: id of the first document; incremented once per document
        :param request_id_start: id of the first request; incremented once per request
        :param random_doc_id: when True, every ``doc_id`` is drawn at random from
            ``[0, ctypes.c_uint(-1).value]`` instead of using the running counter
        :param fill_docs_for: name of the request attribute whose ``docs`` are filled
        :param args: ignored
        :param kwargs: ignored
        """
        # Loop-invariant upper bound for random document ids.
        max_doc_id = ctypes.c_uint(-1).value
        for pi in batch_iterator(data, batch_size):
            req = gnes_pb2.Request()
            req.request_id = request_id_start
            for raw_bytes in pi:
                d = getattr(req, fill_docs_for).docs.add()
                d.doc_id = random.randint(0, max_doc_id) if random_doc_id else doc_id_start
                d.raw_bytes = raw_bytes
                d.weight = 1.0
                d.doc_type = doc_type
                # The counter advances even when random ids are in use.
                doc_id_start += 1
            yield req
            request_id_start += 1

    @staticmethod
    def index(*args, **kwargs):
        """Yield index requests; all arguments are forwarded to :meth:`generate`."""
        yield from RequestGenerator.generate(*args, **kwargs)

    @staticmethod
    def train(*args, **kwargs):
        """
        Yield train requests followed by one final request with ``train.flush = True``.

        Arguments are forwarded to :meth:`generate` with ``fill_docs_for='train'``.
        The flush request is numbered right after the last data request so that
        its ``request_id`` never duplicates one already used in this stream; when
        no data request was produced it keeps the historical id ``1``.
        """
        flush_request_id = 1
        for req in RequestGenerator.generate(*args, **kwargs, fill_docs_for='train'):
            # Read the id before handing the request over to the consumer.
            flush_request_id = req.request_id + 1
            yield req
        req = gnes_pb2.Request()
        req.request_id = flush_request_id
        req.train.flush = True
        yield req

    @staticmethod
    def query(query: bytes, top_k: int, doc_type: int = gnes_pb2.Document.TEXT, request_id_start: int = 0, *args,
              **kwargs):
        """
        Yield a single search request for ``query``.

        :param query: raw payload of the query document
        :param top_k: value stored in ``search.top_k``; must be positive
        :param doc_type: value stored in ``doc_type`` of the query document
        :param request_id_start: id assigned to the request
        :param args: ignored
        :param kwargs: ignored
        :raises ValueError: if ``top_k <= 0``; as this is a generator, the check
            runs when the iterator is first advanced, not at call time
        """
        if top_k <= 0:
            raise ValueError('"top_k: %d" is not a valid number' % top_k)
        req = gnes_pb2.Request()
        req.request_id = request_id_start
        req.search.query.raw_bytes = query
        req.search.query.doc_type = doc_type
        req.search.top_k = top_k
        yield req
def blob2array(blob: 'gnes_pb2.NdArray') -> np.ndarray:
    """
    Rebuild a numpy array from a blob proto.

    ``blob.data`` is interpreted with ``blob.dtype``, copied so that the result
    does not alias the proto's buffer, and reshaped to ``blob.shape``.
    """
    flat_view = np.frombuffer(blob.data, dtype=blob.dtype)
    owned = flat_view.copy()
    return owned.reshape(blob.shape)
def array2blob(x: np.ndarray) -> 'gnes_pb2.NdArray':
    """
    Pack an N-dimensional array into a blob proto.

    The blob stores the array's dtype name, its shape and its raw bytes as
    returned by ``x.tobytes()``.
    """
    dims = list(x.shape)
    blob = gnes_pb2.NdArray()
    blob.dtype = x.dtype.name
    blob.shape.extend(dims)
    blob.data = x.tobytes()
    return blob
def router2str(m: 'gnes_pb2.Message') -> str:
    """
    Join the service names of all routes in the message envelope into one string.

    :param m: message whose ``envelope.routes`` are rendered
    :return: the service names, in envelope order, separated by a green arrow
    """
    route_str = [r.service for r in m.envelope.routes]
    # U+25B8 (small right-pointing triangle). Written as an escape because the
    # literal character had been corrupted into mis-decoded UTF-8 bytes.
    return colored('\u25b8', 'green').join(route_str)
def add_route(evlp: 'gnes_pb2.Envelope', name: str, identity: str):
    """
    Append a new route to the envelope and stamp its start time with "now".

    :param evlp: envelope that receives the route
    :param name: stored as the route's ``service``
    :param identity: stored as the route's ``service_identity``
    """
    route = evlp.routes.add()
    route.service, route.service_identity = name, identity
    route.start_time.GetCurrentTime()
def add_version(evlp: 'gnes_pb2.Envelope'):
    """
    Record the local version triple in the envelope.

    ``version.gnes`` and ``version.proto`` come from the package; ``version.vcs``
    is read from the ``GNES_VCS_VERSION`` environment variable (empty string when
    unset).
    """
    from .. import __version__, __proto_version__

    version = evlp.version
    version.gnes = __version__
    version.proto = __proto_version__
    version.vcs = os.environ.get('GNES_VCS_VERSION', '')
def merge_routes(msg: 'gnes_pb2.Message', prev_msgs: List['gnes_pb2.Message']):
    """
    Replace the routes of ``msg`` with the de-duplicated routes of ``prev_msgs``.

    Routes are keyed by ``service + service_identity``; when a key occurs more
    than once the last occurrence wins. The surviving routes are written back
    ordered by start time (seconds, then nanos).
    """
    unique_routes = {}
    for prev in prev_msgs:
        for route in prev.envelope.routes:
            unique_routes[route.service + route.service_identity] = route

    msg.envelope.ClearField('routes')
    by_start_time = sorted(unique_routes.values(),
                           key=lambda rt: (rt.start_time.seconds, rt.start_time.nanos))
    msg.envelope.routes.extend(by_start_time)
def remove_envelope(m: 'gnes_pb2.Message'):
    """
    Unwrap a message: return whichever ``body`` field is set.

    The envelope's ``request_id`` is copied onto the body, and the end time of
    the first route is stamped with the current time.

    :param m: message to unwrap
    :return: the body sub-message of ``m``
    """
    body_field = m.WhichOneof('body')
    body = getattr(m, body_field)
    body.request_id = m.envelope.request_id

    first_route = m.envelope.routes[0]
    first_route.end_time.GetCurrentTime()
    # NOTE: route logging / dumping (via router2str and make_route_table) used
    # to live here as commented-out code that depended on a ``self`` object.
    return body
def add_envelope(body: 'gnes_pb2.Request', zmq_client: 'ZmqClient', cur_service='FrontendService'):
    """
    Wrap a request into a new message with a freshly populated envelope.

    The envelope gets the client identity, the request id of ``body``, a single
    part (``part_id = 1``, ``num_part = [1]``), a timeout of 5000, the local
    version info and a first route named ``cur_service``.

    :param body: request copied into ``msg.request``
    :param zmq_client: object exposing ``args.identity``
    :param cur_service: service name recorded in the first route
    :return: the new message
    :raises AttributeError: if ``body.request_id`` is None
    """
    msg = gnes_pb2.Message()
    envelope = msg.envelope
    envelope.client_id = zmq_client.args.identity

    if body.request_id is None:
        raise AttributeError('"request_id" is missing or unset!')
    envelope.request_id = body.request_id

    envelope.part_id = 1
    envelope.num_part.append(1)
    envelope.timeout = 5000
    add_version(envelope)
    add_route(envelope, cur_service, zmq_client.args.identity)
    msg.request.CopyFrom(body)
    return msg
def make_route_table(routes, include_frontend: bool = False, jitter: float = 1e-8):
    """
    Render a plain-text timing breakdown of the routes a message went through.

    :param routes: non-empty sequence of route entries, each exposing ``service``,
        ``start_time`` and ``end_time`` (both with ``seconds`` and ``nanos``)
    :param include_frontend: when True the total duration is the span of
        ``routes[0]`` alone; otherwise it runs from ``routes[0].start_time`` to
        ``routes[-1].end_time``
    :param jitter: constant added to the total duration, which is later used as
        the divisor of every percentage
    :return: three sections (header, per-service rows, summary rows) separated
        by 80-dash rulers; rows are sorted by duration, longest first
    """
    # add_envelope() registers the frontend route as 'FrontendService'; the
    # older spelling 'FrontEndService' is still recognised. The frontend is left
    # out of the per-service rows so it is not counted in the 'job' time.
    frontend_names = ('FrontendService', 'FrontEndService')

    def get_duration(start_time, end_time):
        """Seconds between two timestamps, clamped at 0; -1 if either is falsy."""
        if not start_time or not end_time:
            return -1
        d_s = end_time.seconds - start_time.seconds
        d_n = end_time.nanos - start_time.nanos
        # Borrow/carry one second when the two parts have opposite signs.
        if d_s < 0 and d_n > 0:
            d_s = max(d_s + 1, 0)
            d_n = max(d_n - 1e9, 0)
        elif d_s > 0 and d_n < 0:
            d_s = max(d_s - 1, 0)
            d_n = max(d_n + 1e9, 0)
        return max(d_s + d_n / 1e9, 0)

    if include_frontend:
        total_duration = get_duration(routes[0].start_time, routes[0].end_time) + jitter
    else:
        total_duration = get_duration(routes[0].start_time, routes[-1].end_time) + jitter

    route_time = []
    sum_duration = 0
    for k in routes:
        if k.service in frontend_names:
            continue
        d = get_duration(k.start_time, k.end_time)
        route_time.append((k.service, d))
        sum_duration += d

    def get_table_str(time_table):
        """Format (name, seconds) rows, longest first, with share of the total."""
        return '\n'.join(
            ['%40s\t%3.3fs\t%3d%%' % (k[0], k[1], k[1] / total_duration * 100) for k in
             sorted(time_table, key=lambda x: x[1], reverse=True)])

    summary = [('system', total_duration - sum_duration),
               ('total', total_duration),
               ('job', sum_duration),
               ('parallel', max(sum_duration - total_duration, 0))]

    route_table = ('\n%s\n' % ('-' * 80)).join(
        ['%40s\t%-6s\t%3s' % ('Breakdown', 'Time', 'Percent'), get_table_str(route_time),
         get_table_str(summary)])
    return route_table
def check_msg_version(msg: 'gnes_pb2.Message'):
    """
    Compare the version info in the message envelope with the local one.

    Three fields are checked in turn: ``version.gnes`` against the package
    ``__version__``, ``version.proto`` against ``__proto_version__`` and
    ``version.vcs`` against the ``GNES_VCS_VERSION`` environment variable.
    An empty field (or, for vcs, an unset local variable) only logs a warning.

    :param msg: incoming message to verify
    :raises AttributeError: if the envelope has no ``version`` attribute, or if
        any non-empty field differs from the local value
    """
    from .. import __version__, __proto_version__

    if not hasattr(msg.envelope, 'version'):
        raise AttributeError('version_check=True locally, '
                             'but incoming message contains no version info in its envelope. '
                             'the message is probably sent from a very outdated GNES version')

    version = msg.envelope.version

    if not version.gnes:
        # only happen in unittest
        default_logger.warning('incoming message contains empty "version.gnes", '
                               'you may ignore it in debug/unittest mode. '
                               'otherwise please check if frontend service set correct version')
    elif __version__ != version.gnes:
        raise AttributeError('mismatched GNES version! '
                             'incoming message has GNES version %s, whereas local GNES version %s' % (
                                 version.gnes, __version__))

    if not version.proto:
        # only happen in unittest
        default_logger.warning('incoming message contains empty "version.proto", '
                               'you may ignore it in debug/unittest mode. '
                               'otherwise please check if frontend service set correct version')
    elif __proto_version__ != version.proto:
        raise AttributeError('mismatched protobuf version! '
                             'incoming message has protobuf version %s, whereas local protobuf version %s' % (
                                 version.proto, __proto_version__))

    local_vcs = os.environ.get('GNES_VCS_VERSION')
    if not version.vcs or not local_vcs:
        # The separator after "is unset" was missing, which glued two sentences
        # together in the emitted log line.
        default_logger.warning('incoming message contains empty "version.vcs", '
                               'you may ignore it in debug/unittest mode, '
                               'or if you run gnes OUTSIDE docker container where GNES_VCS_VERSION is unset, '
                               'otherwise please check if frontend service set correct version')
    elif local_vcs != version.vcs:
        raise AttributeError('mismatched vcs version! '
                             'incoming message has vcs_version %s, whereas local environment vcs_version is %s' % (
                                 version.vcs, local_vcs))
def extract_bytes_from_msg(msg: 'gnes_pb2.Message') -> Tuple:
    """
    Strip the bulky byte payloads out of a message and return them separately.

    The documents inspected are ``request.train.docs``, else
    ``request.index.docs``, else the single ``request.search.query``. For each
    document the populated ``raw_data`` field (if any) is collected and cleared;
    for each of its chunks the embedding data is always collected and cleared,
    followed by the populated ``content`` field (if any).

    The message is modified in place.

    :param msg: message to strip
    :return: ``(doc_bytes, doc_byte_type, chunk_bytes, chunk_byte_type)``; the
        two ``*_type`` values are the encoded field name seen on the LAST
        document / chunk visited (``b''`` if nothing was set or visited)
    """
    doc_bytes, chunk_bytes = [], []
    doc_byte_type = chunk_byte_type = b''

    request = msg.request
    docs = request.train.docs or request.index.docs or [request.search.query]
    for doc in docs:
        # oneof raw_data: raw_text (string), raw_image / raw_video (NdArray),
        # raw_bytes (bytes, for other types)
        raw_kind = doc.WhichOneof('raw_data') or ''
        doc_byte_type = raw_kind.encode()
        if raw_kind in ('raw_image', 'raw_video'):
            # NdArray payloads: only the data buffer is moved out.
            nd_array = getattr(doc, raw_kind)
            doc_bytes.append(nd_array.data)
            nd_array.ClearField('data')
        elif raw_kind == 'raw_bytes':
            doc_bytes.append(doc.raw_bytes)
            doc.ClearField('raw_bytes')
        elif raw_kind == 'raw_text':
            doc_bytes.append(doc.raw_text.encode())
            doc.ClearField('raw_text')

        for chunk in doc.chunks:
            # The embedding buffer is emitted for every chunk, even when empty.
            chunk_bytes.append(chunk.embedding.data)
            chunk.embedding.ClearField('data')

            # oneof content: text (string), blob (NdArray), raw (bytes)
            content_kind = chunk.WhichOneof('content') or ''
            chunk_byte_type = content_kind.encode()
            if content_kind == 'blob':
                chunk_bytes.append(chunk.blob.data)
                chunk.blob.ClearField('data')
            elif content_kind == 'raw':
                chunk_bytes.append(chunk.raw)
                chunk.ClearField('raw')
            elif content_kind == 'text':
                chunk_bytes.append(chunk.text.encode())
                chunk.ClearField('text')

    return doc_bytes, doc_byte_type, chunk_bytes, chunk_byte_type
def fill_raw_bytes_to_msg(msg: 'gnes_pb2.Message', msg_data: List[bytes]):
    """
    Put separately transmitted byte payloads back into a stripped message.

    This is the receiving counterpart of ``extract_bytes_from_msg``. The frame
    layout of ``msg_data`` is: ``[2]`` document field name, ``[3]`` chunk
    content field name, ``[4]`` number of document frames, ``[5]`` number of
    chunk frames, then the document frames followed by the chunk frames.
    Frames ``[0]`` and ``[1]`` are not read here.

    :param msg: already parsed message, filled in place
    :param msg_data: all frames of the multipart message
    :raises ValueError: if the number of chunk frames differs from frame ``[5]``

    NOTE(review): only one field name per level is transmitted, so messages
    whose documents (or chunks) do not all use the same field appear not to
    round-trip correctly -- confirm against the sender.
    """
    doc_byte_type = msg_data[2].decode()
    chunk_byte_type = msg_data[3].decode()
    doc_bytes_len = int(msg_data[4])
    chunk_bytes_len = int(msg_data[5])

    doc_bytes = msg_data[6:(6 + doc_bytes_len)]
    chunk_bytes = msg_data[(6 + doc_bytes_len):]

    if len(chunk_bytes) != chunk_bytes_len:
        raise ValueError('"chunk_bytes_len"=%d in message, but the actual length is %d' % (
            chunk_bytes_len, len(chunk_bytes)))

    c_idx = 0
    d_idx = 0
    docs = msg.request.train.docs or msg.request.index.docs or [msg.request.search.query]
    for d in docs:
        if doc_bytes and doc_bytes[d_idx]:
            if doc_byte_type == 'raw_bytes':
                d.raw_bytes = doc_bytes[d_idx]
                d_idx += 1
            elif doc_byte_type == 'raw_image':
                d.raw_image.data = doc_bytes[d_idx]
                d_idx += 1
            elif doc_byte_type == 'raw_video':
                d.raw_video.data = doc_bytes[d_idx]
                d_idx += 1
            elif doc_byte_type == 'raw_text':
                d.raw_text = doc_bytes[d_idx].decode()
                d_idx += 1

        for c in d.chunks:
            if chunk_bytes and chunk_bytes[c_idx]:
                c.embedding.data = chunk_bytes[c_idx]
            # The sender emits an embedding frame for EVERY chunk, even an empty
            # one, so the frame is always consumed; skipping the increment for an
            # empty embedding would shift every later frame by one.
            c_idx += 1

            if chunk_byte_type == 'raw':
                c.raw = chunk_bytes[c_idx]
                c_idx += 1
            elif chunk_byte_type == 'blob':
                c.blob.data = chunk_bytes[c_idx]
                c_idx += 1
            elif chunk_byte_type == 'text':
                c.text = chunk_bytes[c_idx].decode()
                c_idx += 1
def send_message(sock: 'zmq.Socket', msg: 'gnes_pb2.Message', timeout: int = -1,
                 squeeze_pb: bool = False, **kwargs) -> None:
    """
    Send a message over a ZeroMQ socket as a multipart message.

    Without ``squeeze_pb`` two frames are sent: the encoded client id and the
    serialized message. With ``squeeze_pb`` the byte payloads are first moved
    out of ``msg`` by ``extract_bytes_from_msg`` (which modifies ``msg`` in
    place) and appended as extra frames after two type frames and two count
    frames.

    :param sock: socket to send on
    :param msg: message to send
    :param timeout: send timeout in ms; any value ``<= 0`` means no timeout
    :param squeeze_pb: send the byte payloads outside the protobuf body
    :param kwargs: ignored
    :raises TimeoutError: if the socket raises ``zmq.error.Again``

    The socket's ``SNDTIMEO`` is always reset to -1 afterwards.
    """
    try:
        sock.setsockopt(zmq.SNDTIMEO, timeout if timeout > 0 else -1)

        if squeeze_pb:
            doc_bytes, doc_byte_type, chunk_bytes, chunk_byte_type = extract_bytes_from_msg(msg)
            # now raw_bytes are removed from message, hoping for faster de/serialization
            frames = [msg.envelope.client_id.encode(),  # 0
                      msg.SerializeToString(),  # 1
                      doc_byte_type, chunk_byte_type,  # 2, 3
                      b'%d' % len(doc_bytes), b'%d' % len(chunk_bytes)]  # 4, 5
            frames.extend(doc_bytes)  # 6 ...
            frames.extend(chunk_bytes)
        else:
            frames = [msg.envelope.client_id.encode(), msg.SerializeToString()]

        sock.send_multipart(frames)
    except zmq.error.Again:
        raise TimeoutError(
            'cannot send message to sock %s after timeout=%dms, please check the following:'
            'is the server still online? is the network broken? are "port" correct? ' % (
                sock, timeout))
    finally:
        sock.setsockopt(zmq.SNDTIMEO, -1)
def recv_message(sock: 'zmq.Socket', timeout: int = -1, check_version: bool = False, **kwargs) -> Optional[
    'gnes_pb2.Message']:
    """
    Receive one multipart message from a ZeroMQ socket and parse it.

    The second frame is parsed as the protobuf message. When more than two
    frames arrive, the extra ones are fed to ``fill_raw_bytes_to_msg`` to
    restore the byte payloads.

    :param sock: socket to receive from
    :param timeout: receive timeout in ms; any value ``<= 0`` means no timeout
    :param check_version: run ``check_msg_version`` on the parsed message
    :param kwargs: ignored
    :return: the parsed message
    :raises TimeoutError: if the socket raises ``zmq.error.Again``

    The socket's ``RCVTIMEO`` is always reset to -1 afterwards.
    """
    try:
        sock.setsockopt(zmq.RCVTIMEO, timeout if timeout > 0 else -1)

        frames = sock.recv_multipart()
        msg = gnes_pb2.Message()
        msg.ParseFromString(frames[1])
        if check_version:
            check_msg_version(msg)

        # now we have a barebone msg, we need to fill in data
        if len(frames) > 2:
            fill_raw_bytes_to_msg(msg, frames)
        return msg
    except zmq.error.Again:
        raise TimeoutError(
            'no response from sock %s after timeout=%dms, please check the following:'
            'is the server still online? is the network broken? are "port" correct? ' % (
                sock, timeout))
    finally:
        sock.setsockopt(zmq.RCVTIMEO, -1)