diff --git a/python/edl/__init__.py b/python/edl/__init__.py
index abf198b9..959839ef 100644
--- a/python/edl/__init__.py
+++ b/python/edl/__init__.py
@@ -11,3 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from edl.collective.state import PaddleState
+from edl.collective.state import State
+from edl.collective.distribute_reader import Reader as DistributeReader
diff --git a/python/edl/collective/distribute_reader.py b/python/edl/collective/distribute_reader.py
index a0a748cf..e0f87995 100644
--- a/python/edl/collective/distribute_reader.py
+++ b/python/edl/collective/distribute_reader.py
@@ -14,378 +14,127 @@
 from __future__ import print_function
 
 import multiprocessing
-import sys
-import threading
-
-from edl.uitls import reader as edl_reader
-from edl.utils import env as edl_env
-from edl.utils import state as edl_state
-
-from edl.utils import data_server
-from edl.utils import data_server_pb2
-from edl.utils import edl_process
-from edl.utils import data_server_client
-from edl.utils import etcd_db
-from edl.utils.log_utils import logger
-from edl.utils import unique_name
+from edl.utils import reader as edl_reader
+from edl.utils import batch_data_generator
+from edl.utils import batch_data_accesser
+from edl.utils import etcd_db
+from edl.utils import trainer_env as edl_trainer_env
 from edl.utils import exceptions
+from edl.utils import unique_name
 from edl.utils.log_utils import logger
+from six.moves import queue
 
 
-class DataGenerator(edl_process.ProcessWrapper):
-    """
-    1. get file_list from data_server_leader
-    2. parse files of file_list and put BatchData to out_quque
-       if reach data end, put None to out_queue.
-    3. program will exit if meets any error
-    """
-
-    def __init__(
-        self,
-        reader_leader_endpoint,
-        reader_name,
-        pod_id,
-        all_files_list,
-        splitter_cls,
-        out_queue,
-    ):
-        super(DataGenerator, self).__init__()
-
-        self._batch_data_id = 0
-
-        self._leader_endpoint = reader_leader_endpoint
-        self._pod_id = pod_id
-        self._reader_name = reader_name
-
-        self._file_list = all_files_list
-        self._splitter_cls = splitter_cls
-        self._data_queue = out_queue
-
-    def _get_file_list(self, timeout=60):
-        client = data_server_client.DataServerClient()
-        return client.get_file_list(
-            leader_endpoint=self._leader_endpoint,
-            reader_name=self._reader_name,
-            pod_id=self._pod_id,
-            file_list=self._file_list,
-        )
-
-    def _generate_batch_data(self):
-        self._batch_data_id += 1
-        b = data_server_pb2.BatchData()
-        b.batch_data_id = self._batch_data_id
-        b.data = None
-
-        return b
-
-    def _read_batch_data(self):
-        b = self._generate_batch_data()
-        for m in self._get_file_list():
-            if self._stop.set():
-                break
-
-            assert self._file_list[m.idx] == m.path
-            for record in self._splitter_cls(m.path):
-                fields = record
-
-                assert fields[0] == m.idx
-                rec = data_server_pb2.Record()
-                rec.record_no = fields[0]
-                for field in fields[1:]:
-                    rec.field_data.append(field)
-
-            if len(b.records) >= self._batch_size:
-                self._data_queue.put(b)
-                b = self._generate_batch_data()
-
-        if len(b.records) > 0:
-            self._data_queue.put(b)
-
-        self._data_queue.put(None)
-
-    def _worker_func(self):
-        try:
-            self._read_batch_data()
-        except Exception as e:
-            print(e, file=sys.stderr)
-            sys.exit(1)
-
-
-class DataAccesser(object):
+class Reader(object):
     def __init__(
-        self,
-        reader_leader_endpoint,
-        reader_name,
-        trainer_env,
-        input_queue,
-        out_queue,
-        queue_size,
+        self, state, file_list, file_splitter_cls, batch_size, cache_capcity=100,
     ):
-        self._reader_leader_endpoint = 
reader_leader_endpoint - - self._reader_name = reader_name - self._trainer_env = trainer_env - self._etcd = etcd_db.get_global_etcd( - self._trainer_env.etcd_endpoint, job_id=self._trainer_env.job_id - ) - - # BatchData - self._input_queue = input_queue - self._out_queue = out_queue - # batch_data_id => BatchData - self._cache = {} - - # pb.BatchDataRequest queue - self._req_queue = threading.Queue(queue_size) - - self._data_server = data_server.DataServer(self) - self._data_server.start() - edl_reader.save_to_etcd( - self._etcd, - reader_name=self._reader_name, - pod_id=self._trainer_env.pod_id, - data_server_endpoint=self._data_server.endpoint, - ) - - self._stop = threading.Event() - self._t_reporter = threading.Thread(target=self.report) - self._t_generater = threading.Thread(target=self.generate) - self._t_accesser = threading.Thread(target=self.access) - - self._client = data_server_client.DataServerClient() - - def start(self): - self._client.connect(self._reader_leader_endpoint) - self._t_reporter.start() - self._t_generater.start() - self._t_accesser.start() - - def _report(self, report_size=10): - """ - 1. Report BatchData index to Leader - 2. Get the BatchData index need to be processed - if there is no data, set None to req_queue - """ - batch_data_ids = [] - while not self._stop.set(): - while len(batch_data_ids) < report_size: - b = self._input_queue.pop() - if b is None: - logger.info("data read to end!") - break - batch_data_ids.append(b.batch_data_id) - with self._lock: - self._cache[b.batch_data_id] = b - - self._client.report_batch_data_meta( - reader_leader_endpoint=self._reader_leader_endpoint, - reader_name=self._name, - pod_id=self._trainer_env.pod_id, - dataserver_endpoint=self._data_server.endpoint, - batch_data_ids=batch_data_ids, - ) - - batch_data_ids = [] - - while not self._stop.set() and len(batch_data_ids) > 0: - self._client.report_batch_data_meta( - reader_leader_endpoint=self._reader_leader_endpoint, - reader_name=self._name, - pod_id=self._trainer_env.pod_id, - dataserver_endpoint=self._data_server.endpoint, - batch_data_ids=batch_data_ids, - ) - - self._client.reach_data_end( - reader_leader_endpoint=self._reader_leader_endpoint, - reader_name=self._name, - pod_id=self._trainer_env.pod_id, - ) - - def _access(self): - while not self._stop.set(): - res = self._client.get_balanced_batch_data( - reader_leader_endpoint=self._reader_leader_endpoint, - reader_name=self._name, - pod_id=self._trainer_env.pod_id, - ) - - self._req_queue.put(res) - - # data end - if res is None: - break - - def _get_batch_data(self, req): - """ - Read BatchData from local or remote by BatchDataRequest - """ - if self._trainer_env.pod_id != req.producer_pod_id: - return (req, self._client.get_batch_data(req)) - - return (req, self.get_local_batch_data(req)) - - def get_local_batch_data(self, req): - ret = [] - for batch_data_id in req.data.batch_data_ids: - with self._lock: - ret.append(self._cache.pop(batch_data_id)) - - return ret - - def _generate(self): - while not self._stop.set(): - req = self._req_queue.pop() - if req is None: - break - - ret = self._get_batch_data(req) - for b in ret: - self._out_queue.put(b) - - self._out_queue.put(None) - - def report(self): - try: - self._report() - except Exception as e: - print(e, file=sys.stderr) - sys.exit(1) - - def access(self): - try: - self._access() - except Exception as e: - print(e, file=sys.stderr) - sys.exit(1) - - def generate(self): - try: - self._generate() - except Exception as e: - print(e, file=sys.stderr) - 
sys.exit(1)
-
-
-def access_batch_data(
-    reader_leader,
-    reader_name,
-    trainer_env,
-    input_queue,
-    out_queue,
-    cache_capcity,
-    error_queue,
-):
-    """
-    Run DataAccesser in a seperated process
-    """
-    try:
-        a = DataAccesser(
-            reader_leader,
-            reader_name,
-            trainer_env,
-            input_queue,
-            out_queue,
-            cache_capcity,
-        )
-        a.start()
-    except KeyboardInterrupt:
-        pass
-    except Exception:
-        import traceback
-
-        error_queue.put(traceback.format_exc())
-        sys.exit(1)
-
-
-class Reader(object):
-    def __init__(self, file_list, file_splitter_cls, batch_size, cache_capcity=100):
         self._file_list = file_list
-        assert isinstance(self._file_list, list), "file_list must be a list"
+        assert isinstance(self._file_list, list), "file_list must be a list of strings"
+        self._state = state
         self._name = unique_name.generator("_dist_reader_")
         self._cls = file_splitter_cls
         self._batch_size = batch_size
+        assert self._batch_size > 0, "batch size must be > 0"
         self._cache_capcity = cache_capcity
-        # connections to data servers
-        self._trainer_env = edl_env.TrainerEnv()
-
-        self._state = edl_state.load_from_etcd(
-            etcd_endpoints=self._trainer_env.etcd_endpoints,
-            job_id=self._trainer_env.job_id,
-            state_name=self._name,
-            timeout=60,
-        )
-
-        self._etcd = etcd_db.get_global_etcd(
-            self._trainer_env.endpoints, self._trainer_env.job_id
-        )
+        # connections to data servers
+        self._trainer_env = edl_trainer_env.TrainerEnv()
+        self._etcd = etcd_db.get_global_etcd(
+            self._trainer_env.etcd_endpoints, job_id=self._trainer_env.job_id
+        )
         # reader meta
         self._reader_leader = edl_reader.load_from_ectd(
             self._etcd, self._trainer_env.pod_leader_id, timeout=60
         )
 
+        self._generator = None
         self._generater_out_queue = multiprocessing.Queue(self._cache_capcity)
-        self._accesser_out_queue = multiprocessing.Queue(self._cache_capcity)
+        self._generater_error_queue = multiprocessing.Queue()
 
-        self._generater = None
         self._accesser = None
+        self._accesser_out_queue = multiprocessing.Queue(self._cache_capcity)
+        self._accesser_error_queue = multiprocessing.Queue()
 
-    def stop(self):
-        if self._generater:
-            self._generater.stop()
-            self._generater = None
+        self._logger_no = 0
 
-        if self._accesser:
-            self._accesser.terminate()
-            self._accesser.join()
-            self._accesser = None
+    def _terminate_process(self, proc):
+        if proc is None:
+            return
+
+        proc.terminate()
+        proc.join()
+
+    def stop(self):
+        self._terminate_process(self._generator)
+        self._terminate_process(self._accesser)
+        self._generator = None
+        self._accesser = None
 
     def __exit__(self):
         self.stop()
 
-    def _check_accesser(self):
-        if self._accesser.is_alive():
+    def _check_proc(self, proc, error_queue):
+        if proc.is_alive():
             return True
 
-        self._accesser.join()
-        exitcode = self._accesser.exitcode
+        proc.join()
+        exitcode = proc.exitcode
         if exitcode == 0:
             return False
 
-        if len(self._error_queue) > 0:
-            raise exceptions.EdlAccessDataError(self.error_queue[0])
-        else:
-            raise exceptions.EdlAccessDataError(
-                "access process exit:{}".format(exitcode)
-            )
-
-    def __iter__(self):
-        self._generater = DataGenerator()
+        if not error_queue.empty():
+            raise exceptions.EdlDataProcessError(error_queue.get())
+
+        raise exceptions.EdlDataProcessError("process exit:{}".format(exitcode))
+
+    def _start_generator(self):
+        args = batch_data_generator.Args()
+        args.state = self._state
+        args.reader_leader_endpoint = self._reader_leader.endpoint
+        args.reader_name = self._reader_leader.name
+        args.pod_id = self._trainer_env.pod_id
+        args.all_files_list = self._file_list
+        args.splitter_cls = self._cls
+        args.batch_size = self._batch_size
+        args.out_queue = self._generater_out_queue
+        args.error_queue = self._generater_error_queue
+        args.loger_file_name = "{}_generator_{}.log".format(self._name, self._logger_no)
+        logger.debug("start generator args 
{}".format(args)) + + self._generator = multiprocessing.Process( + target=batch_data_generator.generate, args=args + ) self._generator.start() + def _start_accesser(self): + args = batch_data_accesser.Args() + args.reader_leader_endpoint = self._reader_leader.endpoint + args.reader_name = self._reader_leader.name + args.input_queue = self._generater_out_queue + args.trainer_env = self._trainer_env + args.out_queue = self._accesser_out_queue + args.queue_size = self._cache_capcity + args.loger_name = "{}_accesser_{}.log".format(self._name, self._logger_no) + logger.debug("start accesser args {}".format(args)) + self._accesser = multiprocessing.Process( - access_batch_data, - args=( - self._reader_leader, - self._name, - self._trainer_env, - self._generater_out_queue, - self._accesser_out_queue, - self._cache_capcity, - ), + batch_data_accesser.generate, args=(args) ) + self._accesser.start() + + def __iter__(self): + self._start_generator() + self._start_accesser() + self._logger_no += 1 + while True: - if not self._check_accesser(): + if not self._check_proc(self._accesser, self._accesser_error_queue): + break + + if not self._check_proc(self._generator, self._generater_error_queue): break try: - b = self._accesser_out_queue.pop(60) + b = self._accesser_out_queue.pop(10) except multiprocessing.Queue.Empty: continue if b is None: - logger.debug("{} reach data end".format(self._name)) + logger.info("distributed reader {} reach data end".format(self._name)) break yield {"meta": b[0], "data": b[1]} diff --git a/python/edl/utils/state.py b/python/edl/collective/state.py similarity index 89% rename from python/edl/utils/state.py rename to python/edl/collective/state.py index bb23002c..f04606f4 100644 --- a/python/edl/utils/state.py +++ b/python/edl/collective/state.py @@ -20,6 +20,8 @@ from edl.utils import string_utils from edl.utils import train_status as edl_train_status from edl.utils import unique_name +from edl.utils import env as edl_env +from edl.discovery import etcd_client class DataCheckpoint(json_serializable.Serializable): @@ -113,6 +115,9 @@ def update_current_epoch_attr(self, epoch_attr): class State(json_serializable.Serializable): def __init__(self, total_batch_size, user_defined=None): + # unique + self._name = unique_name.generator("_edl_state_") + # interface self._default = { "total_batch_size": total_batch_size, # user inputs @@ -121,11 +126,31 @@ def __init__(self, total_batch_size, user_defined=None): self._adjust_func = [] # internal - self._name = unique_name.generator("_edl_state_") self._model_path = None self._data_checkpoint = DataCheckpoint() self._train_status = TrainStatus() + self._restore_from_etcd() + + def _restore_from_etcd(self): + train_env = edl_env.TrainerEnv() + etcd = etcd_client.EtcdClient(train_env.etcd_endpoints, root=train_env.job_id) + etcd.init() + + state = load_from_etcd( + etcd=etcd, + state_name=self._state.name, + user_defined=self._user_defined, + timeout=60, + ) + + self._default = state._default + self._user_defined = state._user_defined + self._adjust_func = state._adjust_func + self._model_apth = state._model_path + self._data_checkpoint = state._data_checkpoint + self._train_status = state._train_status + def from_json(self, json_str): d = json.loads(json_str) diff --git a/python/edl/tests/unittests/CMakeLists.txt b/python/edl/tests/unittests/CMakeLists.txt index 577c07f0..aa07deab 100644 --- a/python/edl/tests/unittests/CMakeLists.txt +++ b/python/edl/tests/unittests/CMakeLists.txt @@ -74,9 +74,8 @@ endfunction() file(GLOB TEST_OPS 
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") #FIXME string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") -LIST(REMOVE_ITEM TEST_OPS test_data_reader) -LIST(REMOVE_ITEM TEST_OPS test_train) LIST(REMOVE_ITEM TEST_OPS test_launch) +LIST(REMOVE_ITEM TEST_OPS test_train) foreach(TEST_OP ${TEST_OPS}) bash_test_modules(${TEST_OP} START_BASH etcd_test.sh ENVS "PYTHON_EXECUTABLE=${PYTHON_EXECUTABLE}") endforeach(TEST_OP) diff --git a/python/edl/tests/unittests/etcd_trainer_test_base.py b/python/edl/tests/unittests/etcd_trainer_test_base.py new file mode 100644 index 00000000..e10efc49 --- /dev/null +++ b/python/edl/tests/unittests/etcd_trainer_test_base.py @@ -0,0 +1,37 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import edl.utils.constants as constants +import edl.utils.log_utils as log_utils +import os +import unittest +from edl.discovery.etcd_client import EtcdClient + +g_etcd_endpoints = "127.0.0.1:2379" + + +class EtcdTestBase(unittest.TestCase): + def setUp(self, job_id): + log_utils.get_logger(log_level=10) + self._etcd = EtcdClient([g_etcd_endpoints], root=job_id) + self._etcd.init() + + self._old_environ = copy.copy(dict(os.environ)) + constants.clean_etcd(self._etcd) + + def tearDown(self): + os.environ.clear() + os.environ.update(self._old_environ) + constants.clean_etcd(self._etcd) diff --git a/python/edl/tests/unittests/launch_demo.py b/python/edl/tests/unittests/launch_demo.py index 9e03adb1..f7f291c0 100644 --- a/python/edl/tests/unittests/launch_demo.py +++ b/python/edl/tests/unittests/launch_demo.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import os import sys diff --git a/python/edl/tests/unittests/test_data_reader.py b/python/edl/tests/unittests/test_data_reader.py index 62e712cc..4a31ecec 100644 --- a/python/edl/tests/unittests/test_data_reader.py +++ b/python/edl/tests/unittests/test_data_reader.py @@ -13,34 +13,91 @@ # limitations under the License. 
import unittest
-from edl.collective.data_reader import DistributedDataReader
-from edl.collective.dataset import TxtFileSplitter
+import os
+from edl.collective import distribute_reader
+from edl.collective import dataset
+from edl.tests.unittests import etcd_trainer_test_base
+from edl.collective import state as edl_state
+
+
+class Args(object):
+    def __init__(self):
+        self.job_id = None
+        self.pod_id = None
+        self.global_rank = None
+        self.rank_in_pod = None
+        self.trainer_endpoints = None
+        self.pod_ids = None
+        self.gpu_id = "0"
+
+
+class TestDataReader(etcd_trainer_test_base.EtcdTestBase):
+    def _init_args(self, pod_id, global_rank, rank_in_pod):
+        args = Args()
+        args.job_id = self._job_id
+        args.pod_id = str(pod_id)
+        args.global_rank = str(global_rank)
+        args.rank_in_pod = str(rank_in_pod)
+        args.trainer_endpoints = ""
+        args.pod_ids = "0,1"
+        args.gpu_id = "0"
+
+        return args
+
+    def _update_env(self, pod_id, global_rank, rank_in_pod):
+        args = self._init_args(pod_id, global_rank, rank_in_pod)
+        proc_env = {
+            "PADDLE_JOB_ID": args.job_id,
+            "PADDLE_POD_ID": args.pod_id,
+            "EDL_POD_LEADER_ID": "0",
+            "PADDLE_ETCD_ENDPOINTS": "127.0.0.1:2379",
+            "PADDLE_TRAINER_ID": args.global_rank,
+            "PADDLE_TRAINER_RANK_IN_POD": args.rank_in_pod,
+            "EDL_POD_IDS": args.pod_ids,
+            "PADDLE_TRAINER_ENDPOINTS": args.trainer_endpoints,
+            "PADDLE_EDL_HDFS_HOME": "/usr/local/hadoop-2.7.7",
+            "PADDLE_EDL_HDFS_NAME": "",
+            "PADDLE_EDL_HDFS_UGI": "",
+            "PADDLE_EDL_HDFS_PATH": "hdfs://{}".format(args.job_id),
+            "PADDLE_EDL_ONLY_FOR_CE_TEST": "1",
+            "PADDLE_EDL_FS_CACHE": ".{}".format(args.job_id),
+            "PADDLE_EDL_SAVE_CHECKPOINT_INTER": "0",
+            "CUDA_VISIBLE_DEVICES": args.gpu_id,
+        }
+        os.environ.pop("https_proxy", None)
+        os.environ.pop("http_proxy", None)
+        os.environ.update(proc_env)
 
-class TestDataReader(unittest.TestCase):
     def setUp(self):
+        self._job_id = "test_data_reader"
+        super(TestDataReader, self).setUp(self._job_id)
+
         self._file_list = ["./data_server/a.txt", "./data_server/b.txt"]
         self._data = {}
         for idx, p in enumerate(self._file_list):
-            s = TxtFileSplitter(p)
-            for r in s:
+            reader = dataset.TxtFileSplitter(p)
+            for rec in reader:
                 if idx not in self._data:
                     self._data[idx] = []
-                d = ((p), (r[0], r[1:]))
-                self._data[idx].append(d)  # [(path),(rec_no, splitted_fiels)]...
+                self._data[idx].append(rec)
 
     def test_data_reader(self):
-        reader1 = DistributedDataReader(
+        self._update_env(pod_id="0", global_rank=0, rank_in_pod=0)
+        state = edl_state.PaddleState(total_batch_size=1)
+        reader1 = distribute_reader.Reader(
+            state=state,
             file_list=self._file_list,
-            file_splitter_cls=TxtFileSplitter,
-            splitted_data_field=["line"],
+            file_splitter_cls=dataset.TxtFileSplitter,
             batch_size=1,
         )
 
-        reader2 = DistributedDataReader(
+        self._update_env(pod_id="1", global_rank=1, rank_in_pod=0)
+        state = edl_state.PaddleState(total_batch_size=1)
+        reader2 = distribute_reader.Reader(
+            state=state,
             file_list=self._file_list,
-            file_splitter_cls=TxtFileSplitter,
-            splitted_data_field=["line"],
+            file_splitter_cls=dataset.TxtFileSplitter,
             batch_size=1,
         )
diff --git a/python/edl/tests/unittests/test_state.py b/python/edl/tests/unittests/test_state.py
index 9be545c5..030ce48b 100644
--- a/python/edl/tests/unittests/test_state.py
+++ b/python/edl/tests/unittests/test_state.py
@@ -17,7 +17,7 @@
 from edl.collective import serializable
 from edl.tests.unittests import etcd_test_base
 from edl.utils import constants
-from edl.utils import state as edl_state
+from edl.collective import state as edl_state
 
 
 class UserDefined(serializable.SerializableBase):
diff --git a/python/edl/tests/unittests/test_train.py b/python/edl/tests/unittests/test_train.py
index a50d1c56..ff4ec9a6 100644
--- a/python/edl/tests/unittests/test_train.py
+++ b/python/edl/tests/unittests/test_train.py
@@ -14,53 +14,47 @@
 
 import unittest
 import edl
-from edl.collective.data_reader import DistributedDataReader, FileMeta
-from edl.collective.dataset import TxtFileSplitter
-from paddle.fluid.incubate.fleet.collective import fleet
-
-learning_rate = 1.0
-start_program = None
-main_program = None
-exe = None
+from edl.tests.unittests import etcd_test_base
+from edl.collective import dataset
 
 
 def adjust():
-    learing_rate = learning_rate * edl.size()  # noqa: F841
+    learning_rate = 1.0 * edl.size()  # noqa: F841
 
 
-class TestDataReader(unittest.TestCase):
-    def setUp(self):
+class TestDataReader(etcd_test_base.EtcdTestBase):
+    def _read_data(self):
         self._file_list = ["./data_server/a.txt", "./data_server/b.txt"]
         self._data = {}
         for idx, p in enumerate(self._file_list):
-            s = TxtFileSplitter(p)
-            m = FileMeta()
-            for r in s:
-                if idx not in m:
+            reader = dataset.TxtFileSplitter(p)
+            for rec in reader:
+                if idx not in self._data:
                     self._data[idx] = []
-                record = ((p), (r[0], r[1:]))
-                self._data[idx].append(record)  # [(path),(rec_no, splitted_fiels)]...
+                self._data[idx].append(rec)
 
     def _train(self, state):
-        print("learning_rate:", learning_rate)
-        reader = DistributedDataReader(
+        reader = edl.DistributeReader(
+            state=state,
             file_list=self._file_list,
-            file_splitter_cls=TxtFileSplitter,
-            splitted_data_field=["line"],
+            file_splitter_cls=dataset.TxtFileSplitter,
             batch_size=1,
-            trainer_rank=0,
         )
 
         for epoch in range(state.epoch, 5):
-            for meta, batch in reader:
-                edl.notify_end_one_batch(meta, state)
+            # the reader yields {"meta": ..., "data": ...} dicts
+            for b in reader:
+                print("epoch_no:", epoch)
+                edl.notify_end_one_batch(b["meta"], state)
 
             edl.notify_end_one_epoch(state)
 
     def test_data_reader(self):
-        fleet.init()
-        state = edl.PaddleState(
-            exe, start_program, main_program, optimizer=None, batch=0, epoch=0
-        )
+        # learning_rate = 1.0
+        start_program = None
+        main_program = None
+        exe = None
+        optimizer = None
+
+        state = edl.PaddleState(exe, start_program, main_program, optimizer)
 
+        self._read_data()
         state.register_adjust_function([adjust])
         self._train(state)
diff --git a/python/edl/utils/batch_data_accesser.py b/python/edl/utils/batch_data_accesser.py
new file mode 100644
index 00000000..e794ca55
--- /dev/null
+++ b/python/edl/utils/batch_data_accesser.py
@@ -0,0 +1,168 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import print_function
+
+import threading
+
+from edl.utils import reader as edl_reader
+from edl.utils import data_server
+from edl.utils import data_server_client
+from edl.utils import etcd_db
+from six.moves import queue
+
+logger = None
+
+
+class Args(object):
+    def __init__(self):
+        self.reader_leader_endpoint = None
+        self.reader_name = None
+        self.trainer_env = None
+        self.input_queue = None
+        self.out_queue = None
+        self.queue_size = None
+        self.error_queue = None
+        self.loger_file_name = None
+
+
+class Accesser(object):
+    """
+    1. get data from batch_data_generator
+    2. get batch_data_meta from data_server_leader
+    3. 
get batch_data by batch_data_meta
+    """
+
+    def __init__(self, args):
+        self._reader_leader_endpoint = args.reader_leader_endpoint
+
+        self._reader_name = args.reader_name
+        self._trainer_env = args.trainer_env
+        # self._etcd = None
+
+        # BatchData
+        self._input_queue = args.input_queue
+        self._out_queue = args.out_queue
+        # batch_data_id => BatchData
+        self._cache = {}
+        self._lock = threading.Lock()
+
+        # pb.BatchDataRequest queue
+        self._req_queue = queue.Queue(args.queue_size)
+
+        self._data_server = None
+
+        self._stop = threading.Event()
+        # self._t_reporter = threading.Thread(target=self._report)
+        self._t_generater = threading.Thread(target=self._generate)
+        self._t_accesser = threading.Thread(target=self._access)
+
+        self._client = data_server_client.DataServerClient()
+
+    def start(self):
+        try:
+            self._start()
+        finally:
+            self._stop.set()
+            self.__exit__()
+
+    def __exit__(self):
+        # if self._t_reporter is not None:
+        #     self._t_reporter.join()
+
+        if self._t_generater is not None:
+            self._t_generater.join()
+
+        if self._t_accesser is not None:
+            self._t_accesser.join()
+
+        # self._t_reporter = None
+        self._t_accesser = None
+        self._t_generater = None
+
+    def _start(self):
+        self._data_server = data_server.Server(self)
+        self._data_server.start()
+
+        etcd = etcd_db.get_global_etcd(
+            self._trainer_env.etcd_endpoints, job_id=self._trainer_env.job_id
+        )
+
+        edl_reader.save_to_etcd(
+            etcd,
+            reader_name=self._reader_name,
+            pod_id=self._trainer_env.pod_id,
+            data_server_endpoint=self._data_server.endpoint,
+            timeout=30,
+        )
+
+        self._client.connect(self._reader_leader_endpoint)
+        # self._t_reporter.start()
+        self._t_generater.start()
+        self._t_accesser.start()
+
+    def _access(self):
+        while not self._stop.is_set():
+            res = self._client.get_batch_data_meta(
+                reader_leader_endpoint=self._reader_leader_endpoint,
+                reader_name=self._reader_name,
+                pod_id=self._trainer_env.pod_id,
+            )
+
+            self._req_queue.put(res)
+
+            # data end
+            if res is None:
+                break
+
+    def _get_batch_data(self, req):
+        """
+        Read BatchData from local or remote by BatchDataRequest
+        """
+        if self._trainer_env.pod_id != req.producer_pod_id:
+            return (req, self._client.get_batch_data(req))
+
+        return (req, self.get_local_batch_data(req))
+
+    def get_local_batch_data(self, req):
+        ret = []
+        for batch_data_id in req.data.batch_data_ids:
+            with self._lock:
+                ret.append(self._cache.pop(batch_data_id))
+
+        return ret
+
+    def _generate(self):
+        while not self._stop.is_set():
+            req = self._req_queue.get()
+            if req is None:
+                break
+
+            ret = self._get_batch_data(req)
+            for b in ret:
+                self._out_queue.put(b)
+
+        self._out_queue.put(None)
+
+
+def generate(args):
+    from edl.utils import log_utils
+
+    global logger
+    logger = log_utils.get_logger(log_level=20, log_file_name=args.loger_file_name)
+    logger.info("args:{}".format(args))
+
+    try:
+        accesser = Accesser(args)
+        accesser.start()
+    except Exception:
+        import traceback
+
+        args.error_queue.put(traceback.format_exc())
diff --git a/python/edl/utils/batch_data_generator.py b/python/edl/utils/batch_data_generator.py
new file mode 100644
index 00000000..475979a2
--- /dev/null
+++ b/python/edl/utils/batch_data_generator.py
@@ -0,0 +1,155 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import print_function
+
+import threading
+
+from edl.utils import data_server_client
+from edl.utils import data_server_pb2
+
+
+logger = None
+
+
+class Args(object):
+    def __init__(self):
+        self.state = None
+        self.reader_leader_endpoint = None
+        self.reader_name = None
+        self.pod_id = None
+        self.all_files_list = None
+        self.splitter_cls = None
+        self.batch_size = None
+        self.out_queue = None
+        self.error_queue = None
+        self.loger_file_name = None
+        # FIXME(gongwb): endpoint of this pod's data server; the server lives
+        # in the accesser process, so it is unknown when the generator starts.
+        self.dataserver_endpoint = None
+
+
+class Generator(object):
+    """
+    1. get file_list from data_server_leader
+    2. parse files of file_list and put BatchData to out_queue
+       if reach data end, put None to out_queue.
+    3. program will exit if meets any error
+    """
+
+    def __init__(self, args):
+        self._state = args.state
+        self._batch_data_id = 0
+
+        self._leader_endpoint = args.reader_leader_endpoint
+        self._reader_name = args.reader_name
+        self._pod_id = args.pod_id
+
+        self._file_list = args.all_files_list
+        self._splitter_cls = args.splitter_cls
+        self._batch_size = args.batch_size
+        self._data_queue = args.out_queue
+        self._dataserver_endpoint = args.dataserver_endpoint
+        self._batch_data_ids = []
+
+        self._stop = threading.Event()
+        self._client = data_server_client.DataServerClient()
+
+    def _get_file_list(self, timeout=60):
+        return self._client.get_file_list(
+            leader_endpoint=self._leader_endpoint,
+            reader_name=self._reader_name,
+            pod_id=self._pod_id,
+            file_list=self._file_list,
+            timeout=timeout,
+        )
+
+    def _generate_batch_data(self):
+        self._batch_data_id += 1
+        b = data_server_pb2.BatchData()
+        b.batch_data_id = self._batch_data_id
+
+        return b
+
+    def _report(self, batch_data, report_size=10):
+        if batch_data is None:
+            # flush the remaining batch_data_ids
+            if len(self._batch_data_ids) > 0:
+                self._client.report_batch_data_meta(
+                    reader_leader_endpoint=self._leader_endpoint,
+                    reader_name=self._reader_name,
+                    pod_id=self._pod_id,
+                    dataserver_endpoint=self._dataserver_endpoint,
+                    batch_data_ids=self._batch_data_ids,
+                )
+                self._batch_data_ids = []
+            return
+
+        self._batch_data_ids.append(batch_data.batch_data_id)
+        if len(self._batch_data_ids) < report_size:
+            return
+
+        self._client.report_batch_data_meta(
+            reader_leader_endpoint=self._leader_endpoint,
+            reader_name=self._reader_name,
+            pod_id=self._pod_id,
+            dataserver_endpoint=self._dataserver_endpoint,
+            batch_data_ids=self._batch_data_ids,
+        )
+        self._batch_data_ids = []
+
+    def _read_batch_data(self):
+        batch_data = self._generate_batch_data()
+        for ele in self._get_file_list():
+            if self._stop.is_set():
+                break
+
+            assert self._file_list[ele.idx] == ele.path
+            logger.info("begin process file {}:{}".format(ele.idx, ele.path))
+
+            for fields in self._splitter_cls(ele.path):
+                rec = data_server_pb2.Record()
+                rec.record_no = fields[0]
+                assert isinstance(
+                    rec.record_no, int
+                ), "first element of splitter_cls must be the record index of this file"
+
+                # FIXME(gongwb) filter it
+                for field in fields[1:]:
+                    rec.field_data.append(field)
+                batch_data.records.append(rec)
+
+                if len(batch_data.records) >= self._batch_size:
+                    yield batch_data
+                    batch_data = self._generate_batch_data()
+
+        if len(batch_data.records) > 0:
+            yield batch_data
+
+    def read_batch_data(self):
+        for batch_data in self._read_batch_data():
+            self._report(batch_data)
+            self._data_queue.put(batch_data)
+
+        
self._report(None)
+        self._data_queue.put(None)
+        self._client.reach_data_end(
+            reader_leader_endpoint=self._leader_endpoint,
+            reader_name=self._reader_name,
+            pod_id=self._pod_id,
+            timeout=60,
+        )
+
+
+def generate(args):
+    from edl.utils import log_utils
+
+    global logger
+    logger = log_utils.get_logger(log_level=20, log_file_name=args.loger_file_name)
+    logger.info("args:{}".format(args))
+
+    try:
+        generator = Generator(args)
+        generator.read_batch_data()
+    except Exception:
+        import traceback
+
+        args.error_queue.put(traceback.format_exc())
diff --git a/python/edl/utils/data_server.py b/python/edl/utils/data_server.py
index 1a37ff8a..48fca30c 100644
--- a/python/edl/utils/data_server.py
+++ b/python/edl/utils/data_server.py
@@ -372,7 +372,7 @@ def GetFileList(self, request, context):
         return res
 
 
-class DataServer(object):
+class Server(object):
     def __init__(self, trainer_env, reader_name, file_list, local_reader):
         self._server = None
         self._addr = None
diff --git a/python/edl/utils/data_server_client.py b/python/edl/utils/data_server_client.py
index a05a25af..e0021255 100644
--- a/python/edl/utils/data_server_client.py
+++ b/python/edl/utils/data_server_client.py
@@ -126,7 +126,7 @@ def get_batch_data_meta(
         exceptions.deserialize(res.status)
 
     logger.debug(
-        "pod client get_balanced_batch_data meta:{}".format(
+        "pod client get_batch_data_meta:{}".format(
             pb_utils.batch_data_meta_response_to_string(res)
         )
    )
diff --git a/python/edl/utils/exceptions.py b/python/edl/utils/exceptions.py
index 981c516e..a223e46a 100644
--- a/python/edl/utils/exceptions.py
+++ b/python/edl/utils/exceptions.py
@@ -81,7 +81,7 @@ class EdlFileListNotMatchError(EdlException):
     pass
 
 
-class EdlDataGenerateError(EdlException):
+class EdlDataProcessError(EdlException):
    pass
 
 
diff --git a/python/edl/utils/env.py b/python/edl/utils/job_env.py
similarity index 80%
rename from python/edl/utils/env.py
rename to python/edl/utils/job_env.py
index 0960affc..642e9a9e 100644
--- a/python/edl/utils/env.py
+++ b/python/edl/utils/job_env.py
@@ -174,56 +174,3 @@ def __str__(self):
         for k, v in six.iteritems(vars(self)):
             s += "{}:{} ".format(k, v)
         return s
-
-
-class TrainerEnv(object):
-    """
-    Parse all envs when edl_launch starts a trainer.
-    
- """ - - def __init__(self, args=None): - self._job_id = os.environ["PADDLE_JOB_ID"] - self._pod_id = os.environ["PADDLE_POD_ID"] - self._pod_leader_id = os.environ["EDL_POD_LEADER_ID"] - self._etcd_endpoints = os.environ["PADDLE_ETCD_ENDPOINTS"] - - self._global_rank = int(os.environ["PADDLE_TRAINER_ID"]) - self._rank_in_pod = int(os.environ["PADDLE_TRAINER_RANK_IN_POD"]) - self._trainer_endpoints = os.environ["PADDLE_TRAINER_ENDPOINTS"] - self._pod_ids = os.environ["EDL_POD_IDS"].split(",") - - @property - def pod_leader_id(self): - return self._pod_leader_id - - @property - def pod_ids(self): - return self._pod_ids - - @property - def pod_id(self): - return self._pod_id - - @property - def global_rank(self): - return self._global_rank - - @property - def rank_in_pod(self): - return self._rank_in_pod - - @property - def trainer_endpoints(self): - return self._trainer_endpoints - - @property - def size(self): - return len(self._trainer_endpoints) - - @property - def job_id(self): - return self._job_id - - @property - def etcd_endpoints(self): - return self._etcd_endpoints diff --git a/python/edl/utils/log_utils.py b/python/edl/utils/log_utils.py index 04135403..c734176c 100644 --- a/python/edl/utils/log_utils.py +++ b/python/edl/utils/log_utils.py @@ -16,9 +16,19 @@ logger = logging.getLogger("root") logger.propagate = False +g_logger_set = False -def get_logger(log_level, name="root"): +def get_logger(log_level, name="root", log_file_name=None): + global g_logger_set + global logger + if g_logger_set: + return logger + g_logger_set = True + + if log_file_name: + logging.basicConfig(filename=log_file_name) + logger = logging.getLogger(name) logger.setLevel(log_level) diff --git a/python/edl/utils/process.py b/python/edl/utils/process.py deleted file mode 100644 index 45981324..00000000 --- a/python/edl/utils/process.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import multiprocessing -import threading - -from edl.utils.log_utils import logger - - -class ProcessWrapper(object): - def __init__(self): - self._stop = None - self._lock = None - self._worker = None - - self._stop = multiprocessing.Event() - self._lock = threading.Lock() - self._worker = multiprocessing.Process(target=self._worker_func) - - def _worker_func(self): - raise NotImplementedError - - def start(self): - self._worker.start() - - def stop(self): - self._stop.set() - with self._lock: - if self._worker: - self._worker.join() - self._worker = None - - logger.info("{} exit".format(self.__class__.__name__)) - - def is_stopped(self): - with self._lock: - return self._worker is None - - def __exit__(self): - self.stop() diff --git a/python/edl/utils/reader.py b/python/edl/utils/reader.py index 3ee04b6d..e0a7b70b 100644 --- a/python/edl/utils/reader.py +++ b/python/edl/utils/reader.py @@ -21,23 +21,23 @@ class ReaderMeta(object): def __init__(self, name, pod_id, data_server_endpoint): - self._name = name - self._pod_id = pod_id - self._endpoint = data_server_endpoint + self.name = name + self.pod_id = pod_id + self.endpoint = data_server_endpoint def to_json(self): d = { - "name": self._name, - "pod_id": self._pod_id, - "endpoint": self._endpoint, + "name": self.name, + "pod_id": self.pod_id, + "endpoint": self.endpoint, } return json.dumps(d) def from_json(self, s): d = json.loads(s) - self._name = d["name"] - self._pod_id = d["pod_id"] - self._endpoint = d["endpoint"] + self.name = d["name"] + self.pod_id = d["pod_id"] + self.endpoint = d["endpoint"] def __str_(self): return self._to_json() @@ -67,7 +67,7 @@ def load_from_etcd(self, etcd, reader_name, pod_id, timeout=60): return meta -def check_dist_readers(etcd): +def check_readers(etcd): servers = etcd.get_service(constants.ETCD_READER) if len(servers) <= 0: diff --git a/python/edl/utils/trainer_env.py b/python/edl/utils/trainer_env.py new file mode 100644 index 00000000..80baa3b1 --- /dev/null +++ b/python/edl/utils/trainer_env.py @@ -0,0 +1,99 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from edl.utils import job_env + + +class TrainerEnv(object): + """ + Parse all envs when edl_launch starts a trainer. 
+ """ + + def __init__(self, args=None): + self._job_id = os.environ["PADDLE_JOB_ID"] + self._pod_id = os.environ["PADDLE_POD_ID"] + self._pod_leader_id = os.environ["EDL_POD_LEADER_ID"] + self._etcd_endpoints = os.environ["PADDLE_ETCD_ENDPOINTS"] + + self._global_rank = int(os.environ["PADDLE_TRAINER_ID"]) + self._rank_in_pod = int(os.environ["PADDLE_TRAINER_RANK_IN_POD"]) + self._trainer_endpoints = os.environ["PADDLE_TRAINER_ENDPOINTS"] + self._pod_ids = os.environ["EDL_POD_IDS"].split(",") + self._ce_test = int(os.getenv("PADDLE_EDL_ONLY_FOR_CE_TEST", "0")) + self._get_hdfs(args) + + def _get_hdfs(self, args): + # hdfs + self._hdfs_home = job_env.get_from_dict_or_env( + args, "hdfs_home", "PADDLE_EDL_HDFS_HOME" + ) + self._hdfs_name = job_env.get_from_dict_or_env( + args, "hdfs_name", "PADDLE_EDL_HDFS_NAME" + ) + self._hdfs_path = job_env.get_from_dict_or_env( + args, "hdfs_path", "PADDLE_EDL_HDFS_PATH" + ) + self._hdfs_ugi = job_env.get_from_dict_or_env( + args, "hdfs_ugi", "PADDLE_EDL_HDFS_UGI" + ) + + # assert hdfs value + if not self._ce_test: + assert ( + len(self._hdfs_home) > 3 + and len(self._hdfs_name) > 6 + and len(self._hdfs_ugi) > 3 + and len(self._hdfs_path) > 0 + ), "hdfs environ must set" + else: + assert ( + len(self._hdfs_home) > 3 and len(self._hdfs_path) > 0 + ), "hdfs environ must set" + + @property + def pod_leader_id(self): + return self._pod_leader_id + + @property + def pod_ids(self): + return self._pod_ids + + @property + def pod_id(self): + return self._pod_id + + @property + def global_rank(self): + return self._global_rank + + @property + def rank_in_pod(self): + return self._rank_in_pod + + @property + def trainer_endpoints(self): + return self._trainer_endpoints + + @property + def size(self): + return len(self._trainer_endpoints) + + @property + def job_id(self): + return self._job_id + + @property + def etcd_endpoints(self): + return self._etcd_endpoints