Source code for gnes.client.cli

import sys
import time
import zipfile
from typing import Iterator, Callable

from termcolor import colored

from .base import GrpcClient
from ..proto import RequestGenerator, gnes_pb2


class CLIClient(GrpcClient):
    """Command-line gRPC client that streams train/index/query requests.

    The payload for every request comes from a bytes iterator that is
    built from the parsed command-line ``args`` (a text file or a zip
    archive) or injected later through the ``bytes_generator`` property.
    """

    def __init__(self, args, auto_start: bool = True):
        """Build the bytes iterator from ``args`` and optionally start.

        :param args: parsed arguments; this class reads ``txt_file``,
            ``image_zip_file``, ``video_zip_file``, ``mode``,
            ``start_doc_id``, ``batch_size`` and ``top_k`` from it.
        :param auto_start: when true, :meth:`start` is called immediately.
        """
        super().__init__(args)
        self._bytes_generator = self._get_bytes_generator_from_args(args)
        if auto_start:
            self.start()

    @staticmethod
    def _iter_zip_members(archive: zipfile.ZipFile) -> Iterator[bytes]:
        """Yield the raw content of every member of an opened archive.

        The archive is closed once iteration finishes or the generator is
        discarded, so no file handle outlives the stream.
        """
        with archive:
            for member in archive.namelist():
                # ZipFile.read opens, reads and closes the member in one go.
                yield archive.read(member)

    @staticmethod
    def _get_bytes_generator_from_args(args):
        """Pick the input source named in ``args`` and wrap it as bytes.

        Precedence is ``txt_file``, then ``image_zip_file``, then
        ``video_zip_file``.

        :return: an iterator of ``bytes``, or ``None`` when no input
            source is given.
        """
        if args.txt_file:
            # presumably an open text file yielding str lines -- confirm
            # against the argument parser
            return (line.encode() for line in args.txt_file)

        # Image and video archives are read the same way: one member per
        # document. ``or`` keeps the original image-before-video precedence.
        zip_path = args.image_zip_file or args.video_zip_file
        if zip_path:
            # Open eagerly so a missing or corrupt archive still fails here,
            # not in the middle of streaming.
            return CLIClient._iter_zip_members(zipfile.ZipFile(zip_path))
        return None

    def start(self, callback: Callable[['gnes_pb2.Message'], None] = None):
        """Run the request type named by ``args.mode`` and close the client.

        Any exception raised by :meth:`call` is logged rather than
        propagated; the client is closed in every case.

        :param callback: optional function invoked with every response.
        """
        try:
            self.call(self.args.mode, callback)
        except Exception as ex:
            self.logger.error(ex)
        finally:
            self.close()

    def call(self, req_type: str, callback: Callable[['gnes_pb2.Message'], None] = None):
        """Stream requests of the given type and report progress.

        :param req_type: one of ``'train'``, ``'index'`` or ``'query'``.
        :param callback: optional function invoked with every response.
        :raises NotImplementedError: if ``req_type`` is not supported.
        :raises ValueError: if no bytes iterator is available.
        """
        if req_type == 'train':
            req_iter = RequestGenerator.train(self.bytes_generator,
                                              doc_id_start=self.args.start_doc_id,
                                              batch_size=self.args.batch_size)
        elif req_type == 'index':
            req_iter = RequestGenerator.index(self.bytes_generator,
                                              doc_id_start=self.args.start_doc_id,
                                              batch_size=self.args.batch_size)
        elif req_type == 'query':
            # one query request is built per input item
            req_iter = (RequestGenerator.query(q, top_k=self.args.top_k)
                        for q in self.bytes_generator)
        else:
            raise NotImplementedError('unsupported request type: %r' % (req_type,))

        with ProgressBar(task_name=self.args.mode) as p_bar:
            for resp in self._stub.StreamCall(req_iter):
                if callback:
                    callback(resp)
                p_bar.update()

    @property
    def bytes_generator(self) -> Iterator[bytes]:
        """The bytes iterator feeding the requests.

        :raises ValueError: if the stored iterator is falsy (``None`` or
            an empty container).
        """
        if self._bytes_generator:
            return self._bytes_generator
        else:
            raise ValueError('bytes_generator is empty or not set')

    @bytes_generator.setter
    def bytes_generator(self, bytes_gen: Iterator[bytes]):
        """Replace the bytes iterator, warning if one was already set."""
        if self._bytes_generator:
            self.logger.warning('bytes_generator is not empty, overrided')
        self._bytes_generator = bytes_gen
class ProgressBar:
    """Single-line console progress bar driven by :meth:`update`.

    Intended to be used as a context manager: entering resets the clock
    and draws the initial bar, every :meth:`update` call counts one
    finished batch, and leaving prints a final ``done!`` marker.
    """

    def __init__(self, bar_len: int = 20, task_name: str = ''):
        """
        :param bar_len: number of ``=`` characters in a full bar.
        :param task_name: label printed in front of the bar.
        """
        self.bar_len = bar_len
        self.task_name = task_name
        # Initialised here so update() also works outside a ``with`` block;
        # __enter__ resets both values.
        self.num_bars = 0
        self.start_time = time.perf_counter()

    def update(self):
        """Count one more batch and redraw the bar on the current line."""
        self.num_bars += 1
        sys.stdout.write('\r')
        elapsed = time.perf_counter() - self.start_time
        elapsed_str = colored('elapsed', 'yellow')
        speed_str = colored('speed', 'yellow')
        # The clock may not have advanced yet on the very first redraw;
        # report a speed of zero instead of dividing by zero.
        speed = self.num_bars / elapsed if elapsed > 0 else 0.0
        # The bar wraps around every ``bar_len`` updates: an exact multiple
        # shows a full bar, anything else shows at least one segment.
        num_bars = self.num_bars % self.bar_len
        num_bars = self.bar_len if not num_bars and self.num_bars else max(num_bars, 1)
        sys.stdout.write(
            '{:>10} [{:<{}}] {:>8}: {:3.1f}s {:>8}: {:3.1f} batch/s'.format(
                colored(self.task_name, 'cyan'),
                colored('=' * num_bars, 'green'),
                # presumably the extra 9 columns compensate for the
                # invisible colour escape codes added by colored() -- confirm
                self.bar_len + 9,
                elapsed_str,
                elapsed,
                speed_str,
                speed,
            ))
        if num_bars == self.bar_len:
            # a full bar is finished; continue on a fresh line
            sys.stdout.write('\n')
        sys.stdout.flush()

    def __enter__(self):
        """Reset the clock and counter, draw the initial bar, return self."""
        self.start_time = time.perf_counter()
        # update() increments first, so start at -1 to display zero batches
        self.num_bars = -1
        self.update()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Print the closing marker; exceptions are not suppressed."""
        sys.stdout.write('\t%s\n' % colored('done!', 'green'))