import ctypes
import os
import random
from typing import List, Iterator, Tuple
from typing import Optional
import numpy as np
import zmq
from termcolor import colored
from . import gnes_pb2
from ..helper import batch_iterator, default_logger
# Public API of this module (what ``from ... import *`` exposes).
__all__ = [
    'RequestGenerator',
    'send_message',
    'recv_message',
    'blob2array',
    'array2blob',
    'gnes_pb2',
    'add_route',
    'add_version',
]
class RequestGenerator:
    """Helpers that turn raw payloads into streams of ``gnes_pb2.Request`` messages.

    Every helper is a static method producing an iterator of requests whose
    ``request_id`` values count up from ``request_id_start``.
    """

    # Upper bound (inclusive) for randomly drawn document ids: the largest C ``unsigned int``.
    _MAX_RANDOM_DOC_ID = ctypes.c_uint(-1).value

    @staticmethod
    def _next_doc_id(doc_id_start: int, random_doc_id: bool) -> int:
        """Return ``doc_id_start``, or a random id in ``[0, _MAX_RANDOM_DOC_ID]`` if ``random_doc_id``."""
        if random_doc_id:
            return random.randint(0, RequestGenerator._MAX_RANDOM_DOC_ID)
        return doc_id_start

    @staticmethod
    def index(data: Iterator[bytes], batch_size: int = 0, doc_type: int = gnes_pb2.Document.TEXT,
              doc_id_start: int = 0, request_id_start: int = 0,
              random_doc_id: bool = False,
              *args, **kwargs):
        """Yield one index ``Request`` per batch of ``data``.

        :param data: raw document payloads.
        :param batch_size: forwarded to ``batch_iterator`` to group ``data``.
        :param doc_type: value written to every document's ``doc_type``.
        :param doc_id_start: first sequential document id (ignored when ``random_doc_id``).
        :param request_id_start: ``request_id`` of the first yielded request.
        :param random_doc_id: draw each document id at random instead of sequentially.
        :return: generator of requests whose ``index.docs`` hold the batch, each doc with weight 1.0.
        """
        for pi in batch_iterator(data, batch_size):
            req = gnes_pb2.Request()
            req.request_id = request_id_start
            for raw_bytes in pi:
                d = req.index.docs.add()
                d.doc_id = RequestGenerator._next_doc_id(doc_id_start, random_doc_id)
                d.raw_bytes = raw_bytes
                d.weight = 1.0
                d.doc_type = doc_type
                doc_id_start += 1
            yield req
            request_id_start += 1

    @staticmethod
    def train(data: Iterator[bytes], batch_size: int = 0, doc_type: int = gnes_pb2.Document.TEXT,
              doc_id_start: int = 0, request_id_start: int = 0,
              random_doc_id: bool = False,
              *args, **kwargs):
        """Yield one train ``Request`` per batch of ``data``, then a final flush request.

        Parameters are as for :meth:`index`. After the last batch, one extra
        request with ``train.flush = True`` and no documents is yielded.
        """
        for pi in batch_iterator(data, batch_size):
            req = gnes_pb2.Request()
            req.request_id = request_id_start
            for raw_bytes in pi:
                d = req.train.docs.add()
                d.doc_id = RequestGenerator._next_doc_id(doc_id_start, random_doc_id)
                d.raw_bytes = raw_bytes
                d.doc_type = doc_type
                if not random_doc_id:
                    doc_id_start += 1
            yield req
            request_id_start += 1
        # signal the end of training data
        req = gnes_pb2.Request()
        req.request_id = request_id_start
        req.train.flush = True
        yield req
        request_id_start += 1

    @staticmethod
    def query(query: bytes, top_k: int, doc_type: int = gnes_pb2.Document.TEXT, request_id_start: int = 0, *args,
              **kwargs):
        """Return an iterator yielding a single search ``Request`` for ``query``.

        :param query: raw query payload.
        :param top_k: number of results requested; must be positive.
        :param doc_type: ``doc_type`` of the query document.
        :param request_id_start: ``request_id`` of the request.
        :raises ValueError: if ``top_k <= 0``. This is raised when ``query`` is
            called, not deferred until the first iteration.
        """
        if top_k <= 0:
            raise ValueError('"top_k: %d" is not a valid number' % top_k)
        return RequestGenerator._query_iter(query, top_k, doc_type, request_id_start)

    @staticmethod
    def _query_iter(query: bytes, top_k: int, doc_type: int, request_id_start: int):
        """Lazily build and yield the single search request for :meth:`query`."""
        req = gnes_pb2.Request()
        req.request_id = request_id_start
        req.search.query.raw_bytes = query
        req.search.query.doc_type = doc_type
        req.search.top_k = top_k
        yield req
def blob2array(blob: 'gnes_pb2.NdArray') -> np.ndarray:
    """Decode an ``NdArray`` proto into a numpy array.

    The bytes in ``blob.data`` are interpreted with ``blob.dtype`` and reshaped
    to ``blob.shape``. The result owns its memory (it is a copy, not a view on
    the proto's buffer).
    """
    flat_view = np.frombuffer(blob.data, dtype=blob.dtype)
    owned = np.array(flat_view, copy=True)
    return owned.reshape(blob.shape)
def array2blob(x: np.ndarray) -> 'gnes_pb2.NdArray':
    """Convert an N-dimensional array into an ``NdArray`` proto.

    The blob stores the raw bytes (C order, from ``tobytes()``), the shape and
    the dtype *name* (e.g. ``'float32'``).

    :param x: array to serialize.
    :return: the populated ``gnes_pb2.NdArray``.
    """
    # The dtype name carries no byte order, and ``blob2array`` decodes it in
    # native order, so a non-native (e.g. big-endian) array would come back with
    # scrambled values. Convert to native byte order first; the wire format is unchanged.
    if not x.dtype.isnative:
        x = x.astype(x.dtype.newbyteorder('='))
    blob = gnes_pb2.NdArray()
    blob.data = x.tobytes()
    blob.shape.extend(list(x.shape))
    blob.dtype = x.dtype.name
    return blob
def router2str(m: 'gnes_pb2.Message') -> str:
    """Render the service names along ``m.envelope.routes`` joined by a green triangle."""
    route_str = [r.service for r in m.envelope.routes]
    # U+25B8 BLACK RIGHT-POINTING SMALL TRIANGLE, written as an escape so the
    # separator cannot be mis-encoded again (it previously read 'â–¸').
    return colored('\u25b8', 'green').join(route_str)
def add_route(evlp: 'gnes_pb2.Envelope', name: str, identity: str):
    """Append a new route entry to ``evlp``.

    :param evlp: envelope whose ``routes`` list is extended.
    :param name: value for the route's ``service``.
    :param identity: value for the route's ``service_identity``.

    The route's ``start_time`` is stamped with the current time.
    """
    hop = evlp.routes.add()
    hop.service = name
    hop.start_time.GetCurrentTime()
    hop.service_identity = identity
def add_version(evlp: 'gnes_pb2.Envelope'):
    """Stamp ``evlp`` with the local GNES, protobuf and VCS versions.

    ``vcs_version`` comes from the ``GNES_VCS_VERSION`` environment variable
    and is the empty string when that variable is unset.
    """
    from .. import __version__ as local_gnes_version, __proto_version__ as local_proto_version
    evlp.gnes_version = local_gnes_version
    evlp.proto_version = local_proto_version
    evlp.vcs_version = os.environ.get('GNES_VCS_VERSION', '')
def merge_routes(msg: 'gnes_pb2.Message', prev_msgs: List['gnes_pb2.Message']):
    """Replace ``msg``'s routes with the de-duplicated routes of ``prev_msgs``.

    Routes are de-duplicated by ``(service, service_identity)``. When a pair
    occurs more than once, the last occurrence wins. The result is sorted by
    ``start_time`` (seconds, then nanos). Any routes already on ``msg`` are
    cleared first, so they survive only if ``msg`` is itself in ``prev_msgs``.
    """
    # Use a tuple key: the old concatenated-string key ``service + service_identity``
    # made e.g. ('ab', 'c') and ('a', 'bc') collide and silently dropped a route.
    routes = {(r.service, r.service_identity): r for m in prev_msgs for r in m.envelope.routes}
    msg.envelope.ClearField('routes')
    msg.envelope.routes.extend(sorted(routes.values(), key=lambda x: (x.start_time.seconds, x.start_time.nanos)))
def check_msg_version(msg: 'gnes_pb2.Message'):
    """Check that ``msg`` was produced by a compatible GNES deployment.

    Compares the envelope's ``gnes_version``, ``proto_version`` and ``vcs_version``
    with the local ``__version__``, ``__proto_version__`` and the
    ``GNES_VCS_VERSION`` environment variable. An empty incoming value (or, for
    ``vcs_version``, an unset/empty local variable) only logs a warning.

    :raises AttributeError: on a version mismatch, or when the envelope has
        neither ``gnes_version`` nor ``proto_version``.
    """
    from .. import __version__, __proto_version__
    # read once so the comparison and the error message see the same value
    local_vcs_version = os.environ.get('GNES_VCS_VERSION')

    if hasattr(msg.envelope, 'gnes_version'):
        if not msg.envelope.gnes_version:
            # only happen in unittest
            default_logger.warning('incoming message contains empty "gnes_version", '
                                   'you may ignore it in debug/unittest mode. '
                                   'otherwise please check if frontend service set correct version')
        elif __version__ != msg.envelope.gnes_version:
            raise AttributeError('mismatched GNES version! '
                                 'incoming message has GNES version %s, whereas local GNES version %s' % (
                                     msg.envelope.gnes_version, __version__))

    if hasattr(msg.envelope, 'proto_version'):
        if not msg.envelope.proto_version:
            # only happen in unittest
            default_logger.warning('incoming message contains empty "proto_version", '
                                   'you may ignore it in debug/unittest mode. '
                                   'otherwise please check if frontend service set correct version')
        elif __proto_version__ != msg.envelope.proto_version:
            raise AttributeError('mismatched protobuf version! '
                                 'incoming message has protobuf version %s, whereas local protobuf version %s' % (
                                     msg.envelope.proto_version, __proto_version__))

    if hasattr(msg.envelope, 'vcs_version'):
        if not msg.envelope.vcs_version or not local_vcs_version:
            # the separator after "is unset" was missing, gluing it to "otherwise"
            default_logger.warning('incoming message contains empty "vcs_version", '
                                   'you may ignore it in debug/unittest mode, '
                                   'or if you run gnes OUTSIDE docker container where GNES_VCS_VERSION is unset, '
                                   'otherwise please check if frontend service set correct version')
        elif local_vcs_version != msg.envelope.vcs_version:
            raise AttributeError('mismatched vcs version! '
                                 'incoming message has vcs_version %s, whereas local environment vcs_version is %s' % (
                                     msg.envelope.vcs_version, local_vcs_version))

    if not hasattr(msg.envelope, 'proto_version') and not hasattr(msg.envelope, 'gnes_version'):
        raise AttributeError('version_check=True locally, '
                             'but incoming message contains no version info in its envelope. '
                             'the message is probably sent from a very outdated GNES version')
def extract_bytes_from_msg(msg: 'gnes_pb2.Message') -> Tuple[List[bytes], bytes, List[bytes], bytes]:
    """Strip raw payload bytes out of ``msg`` in place and return them separately.

    Used so the remaining (smaller) protobuf can be serialized quickly while the
    payloads travel as separate frames.

    The documents processed are ``msg.request.train.docs`` if non-empty, else
    ``msg.request.index.docs`` if non-empty, else the single
    ``msg.request.search.query``.

    :return: ``(doc_bytes, doc_byte_type, chunk_bytes, chunk_byte_type)``.

        - ``doc_bytes`` holds one entry per document that has a ``raw_data``
          field set (raw_text is UTF-8 encoded).
        - For every chunk, ``chunk_bytes`` gets its ``embedding.data`` first
          (always, even if empty), then its ``content`` bytes if the chunk has
          ``raw``, ``blob`` or ``text`` set.
        - ``doc_byte_type`` / ``chunk_byte_type`` are the encoded oneof field
          names of the LAST document / chunk seen.

    NOTE(review): because only the last type is reported, this appears to
    assume all docs share one ``raw_data`` kind and all chunks one ``content``
    kind. Mixed messages (or a trailing doc/chunk with nothing set, which
    reports ``b''``) would not be restored faithfully by
    ``fill_raw_bytes_to_msg``. Confirm callers only send homogeneous messages.
    """
    doc_bytes = []
    chunk_bytes = []
    doc_byte_type = b''
    chunk_byte_type = b''
    # train docs, else index docs, else the single search query
    docs = msg.request.train.docs or msg.request.index.docs or [msg.request.search.query]
    for d in docs:
        # oneof raw_data {
        #     string raw_text = 5;
        #     NdArray raw_image = 6;
        #     NdArray raw_video = 7;
        #     bytes raw_bytes = 8; // for other types
        # }
        dtype = d.WhichOneof('raw_data') or ''
        # overwritten per doc: the returned type is that of the last doc
        doc_byte_type = dtype.encode()
        if dtype == 'raw_bytes':
            doc_bytes.append(d.raw_bytes)
            d.ClearField('raw_bytes')
        elif dtype == 'raw_image':
            doc_bytes.append(d.raw_image.data)
            d.raw_image.ClearField('data')
        elif dtype == 'raw_video':
            doc_bytes.append(d.raw_video.data)
            d.raw_video.ClearField('data')
        elif dtype == 'raw_text':
            doc_bytes.append(d.raw_text.encode())
            d.ClearField('raw_text')
        for c in d.chunks:
            # oneof content {
            #     string text = 2;
            #     NdArray blob = 3;
            #     bytes raw = 7;
            # }
            # the embedding entry is appended unconditionally, so every chunk
            # contributes at least one (possibly empty) entry
            chunk_bytes.append(c.embedding.data)
            c.embedding.ClearField('data')
            ctype = c.WhichOneof('content') or ''
            # overwritten per chunk: the returned type is that of the last chunk
            chunk_byte_type = ctype.encode()
            if ctype == 'raw':
                chunk_bytes.append(c.raw)
                c.ClearField('raw')
            elif ctype == 'blob':
                chunk_bytes.append(c.blob.data)
                c.blob.ClearField('data')
            elif ctype == 'text':
                chunk_bytes.append(c.text.encode())
                c.ClearField('text')
    return doc_bytes, doc_byte_type, chunk_bytes, chunk_byte_type
def fill_raw_bytes_to_msg(msg: 'gnes_pb2.Message', msg_data: List[bytes]):
    """Put payload bytes carried in extra frames back into ``msg`` in place.

    ``msg_data`` is the multipart frame list in the layout written by
    ``send_message(..., squeeze_pb=True)``::

        [client_id, serialized_msg, doc_byte_type, chunk_byte_type,
         n_doc_frames, n_chunk_frames, *doc_frames, *chunk_frames]

    Documents (train docs, else index docs, else the search query) consume
    doc frames in order, one per doc, while frames remain and
    ``doc_byte_type`` names a known ``raw_data`` field. Each chunk first
    consumes one frame for its embedding, assigned only if non-empty. It then
    consumes one more for its content when ``chunk_byte_type`` names a known
    ``content`` field. Empty frames are consumed as well, so one empty payload
    never stalls the cursor for every later doc/chunk.

    :raises ValueError: if the announced frame counts do not match the frames present.
    """
    doc_byte_type = msg_data[2].decode()
    chunk_byte_type = msg_data[3].decode()
    doc_bytes_len = int(msg_data[4])
    chunk_bytes_len = int(msg_data[5])
    doc_bytes = msg_data[6:(6 + doc_bytes_len)]
    chunk_bytes = msg_data[(6 + doc_bytes_len):]

    if len(doc_bytes) != doc_bytes_len:
        raise ValueError('"doc_bytes_len"=%d in message, but the actual length is %d' % (
            doc_bytes_len, len(doc_bytes)))
    if len(chunk_bytes) != chunk_bytes_len:
        raise ValueError('"chunk_bytes_len"=%d in message, but the actual length is %d' % (
            chunk_bytes_len, len(chunk_bytes)))

    c_idx = 0
    d_idx = 0
    docs = msg.request.train.docs or msg.request.index.docs or [msg.request.search.query]
    for d in docs:
        if d_idx < len(doc_bytes):
            if doc_byte_type == 'raw_bytes':
                d.raw_bytes = doc_bytes[d_idx]
                d_idx += 1
            elif doc_byte_type == 'raw_image':
                d.raw_image.data = doc_bytes[d_idx]
                d_idx += 1
            elif doc_byte_type == 'raw_video':
                d.raw_video.data = doc_bytes[d_idx]
                d_idx += 1
            elif doc_byte_type == 'raw_text':
                d.raw_text = doc_bytes[d_idx].decode()
                d_idx += 1
        for c in d.chunks:
            # every chunk was sent with exactly one (possibly empty) embedding frame
            if c_idx < len(chunk_bytes):
                if chunk_bytes[c_idx]:
                    c.embedding.data = chunk_bytes[c_idx]
                c_idx += 1
            if c_idx < len(chunk_bytes):
                if chunk_byte_type == 'raw':
                    c.raw = chunk_bytes[c_idx]
                    c_idx += 1
                elif chunk_byte_type == 'blob':
                    c.blob.data = chunk_bytes[c_idx]
                    c_idx += 1
                elif chunk_byte_type == 'text':
                    c.text = chunk_bytes[c_idx].decode()
                    c_idx += 1
def send_message(sock: 'zmq.Socket', msg: 'gnes_pb2.Message', timeout: int = -1,
                 squeeze_pb: bool = False, **kwargs) -> None:
    """Send ``msg`` on ``sock`` as a multipart message addressed by its client id.

    :param sock: zmq socket to send on.
    :param msg: message to send.
    :param timeout: send timeout in ms; any value <= 0 means block indefinitely.
    :param squeeze_pb: if True, move payload bytes out of the protobuf into extra
        frames (see frame layout below). NOTE: this clears those payload fields
        on ``msg`` itself via ``extract_bytes_from_msg``.
    :raises TimeoutError: if the send times out.

    The send timeout is always reset to -1 afterwards.
    """
    try:
        sock.setsockopt(zmq.SNDTIMEO, timeout if timeout > 0 else -1)
        client_id = msg.envelope.client_id.encode()
        if not squeeze_pb:
            sock.send_multipart([client_id, msg.SerializeToString()])
        else:
            doc_bytes, doc_byte_type, chunk_bytes, chunk_byte_type = extract_bytes_from_msg(msg)
            # now raw_bytes are removed from message, hoping for faster de/serialization
            sock.send_multipart(
                [client_id,  # 0
                 msg.SerializeToString(),  # 1
                 doc_byte_type, chunk_byte_type,  # 2, 3
                 b'%d' % len(doc_bytes), b'%d' % len(chunk_bytes),  # 4, 5
                 *doc_bytes,  # 6 .. 5 + len(doc_bytes)
                 *chunk_bytes])  # 6 + len(doc_bytes) .. end
    except zmq.error.Again as ex:
        raise TimeoutError(
            'cannot send message to sock %s after timeout=%dms, please check the following: '
            'is the server still online? is the network broken? are "port" correct? ' % (
                sock, timeout)) from ex
    finally:
        sock.setsockopt(zmq.SNDTIMEO, -1)
def recv_message(sock: 'zmq.Socket', timeout: int = -1, check_version: bool = False, **kwargs) -> Optional[
    'gnes_pb2.Message']:
    """Receive one multipart message from ``sock`` and parse it into a ``gnes_pb2.Message``.

    :param sock: zmq socket to receive from.
    :param timeout: receive timeout in ms; any value <= 0 means block indefinitely.
    :param check_version: if True, validate the envelope with ``check_msg_version``.
    :return: the parsed message; payload frames beyond the first two are put
        back into it with ``fill_raw_bytes_to_msg``.
    :raises TimeoutError: if nothing arrives within ``timeout``.

    The receive timeout is always reset to -1 afterwards.
    """
    try:
        sock.setsockopt(zmq.RCVTIMEO, timeout if timeout > 0 else -1)
        msg = gnes_pb2.Message()
        msg_data = sock.recv_multipart()
        # frame 0 is the client id, frame 1 the serialized protobuf
        msg.ParseFromString(msg_data[1])
        if check_version:
            check_msg_version(msg)
        # extra frames mean the sender squeezed payload bytes out of the protobuf
        if len(msg_data) > 2:
            fill_raw_bytes_to_msg(msg, msg_data)
        return msg
    except zmq.error.Again as ex:
        raise TimeoutError(
            'no response from sock %s after timeout=%dms, please check the following: '
            'is the server still online? is the network broken? are "port" correct? ' % (
                sock, timeout)) from ex
    finally:
        sock.setsockopt(zmq.RCVTIMEO, -1)