From 58406fd8272cad14730dc7cbb5dfd576c48bc703 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Fri, 8 May 2020 07:29:41 +0000 Subject: [PATCH 1/9] add --- .github/issue_template.md | 2 +- .github/pull_request_template.md | 13 ++++++------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/.github/issue_template.md b/.github/issue_template.md index b0ea00dc..8dea127e 100644 --- a/.github/issue_template.md +++ b/.github/issue_template.md @@ -22,7 +22,7 @@ about: 您可以提问训练中报错、应用、出core等问题。 You could u - 复现信息:如为报错,请给出复现环境、复现步骤 - 问题描述:请详细描述您的问题,同步贴出报错信息、日志、可复现的代码片段 -Thank you for contributing to PaddlePaddle. +Thank you for contributing to EDL. Before submitting the issue, you could search issue in the github in case that there was a similar issue submitted or resolved before. If there is no solution,please make sure that this is a training issue including the following details: **System information** diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 515a4f04..08b6f28c 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,12 +1,11 @@ -# What this PR does / why we need it: +## What this PR does / why we need it: -# Which issue(s) this PR fixes: +## Which issue(s) this PR fixes: -## Fixes # +### Fixes # -# Special notes for your reviewer: +## Special notes for your reviewer: -# Does this PR introduce a user-facing change?: +## Does this PR introduce a user-facing change? - -# Additional documentation? +## Additional documentation? From 78e7c8f479e90069aff24e0832c1d5c6aa223340 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Mon, 11 May 2020 09:12:13 +0000 Subject: [PATCH 2/9] merge --- .github/issue_template.md | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/issue_template.md b/.github/issue_template.md index dbeda3fa..871957cc 100644 --- a/.github/issue_template.md +++ b/.github/issue_template.md @@ -17,7 +17,6 @@ about: 您可以提问训练中报错、应用、出core等问题。 You could u - 复现信息:如为报错,请给出复现环境、复现步骤 - 问题描述:请详细描述您的问题,同步贴出报错信息、日志、可复现的代码片段 - Thank you for contributing to EDL. Before submitting the issue, you could search the issue in the GitHub in case that there was a similar issue submitted or resolved before. 
If there is no solution, please make sure that this is a training issue including the following details: From 39ad87e8787229d1a19ef2dc7d34d98011e5b25e Mon Sep 17 00:00:00 2001 From: gongweibao Date: Wed, 23 Sep 2020 17:40:36 +0800 Subject: [PATCH 3/9] add --- python/edl/collective/distribute_reader.py | 8 +++---- python/edl/collective/state.py | 0 python/edl/tests/unittests/launch_demo.py | 2 ++ python/edl/utils/reader.py | 2 +- python/edl/utils/state.py | 25 +++++++++++++++++++++- 5 files changed, 31 insertions(+), 6 deletions(-) create mode 100644 python/edl/collective/state.py diff --git a/python/edl/collective/distribute_reader.py b/python/edl/collective/distribute_reader.py index 99840256..e6b068a2 100644 --- a/python/edl/collective/distribute_reader.py +++ b/python/edl/collective/distribute_reader.py @@ -280,14 +280,14 @@ def __init__(self, # connections to data servers self._trainer_env = edl_env.TrainerEnv() + self._etcd = etcd_db.get_global_etcd(self._trainer_env.endpoints, + self._trainer_env.job_id) + self._state = edl_state.load_from_etcd( - etcd_endpoints=self._trainer_env.etcd_endpoints, - job_id=self._trainer_env.job_id, + etcd=self._etcd, state_name=self._name, timeout=60) - self._etcd = etcd_db.get_global_etcd(self._trainer_env.endpoints, - self._trainer_env.job_id) # reader meta self._reader_leader = edl_reader.load_from_ectd( self._etcd, self._trainer_env.pod_leader_id, timeout=60) diff --git a/python/edl/collective/state.py b/python/edl/collective/state.py new file mode 100644 index 00000000..e69de29b diff --git a/python/edl/tests/unittests/launch_demo.py b/python/edl/tests/unittests/launch_demo.py index 9e03adb1..f7f291c0 100644 --- a/python/edl/tests/unittests/launch_demo.py +++ b/python/edl/tests/unittests/launch_demo.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
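[Editor's note] The distribute_reader.py hunk in this patch reorders initialization so that the shared etcd handle is created first and handed to the state loader, instead of letting load_from_etcd dial its own connection from raw endpoints. A minimal sketch of the resulting call order, assuming TrainerEnv exposes endpoints and job_id as the diff suggests (the state name is a hypothetical placeholder):

    # Sketch only; names follow the diff above.
    trainer_env = edl_env.TrainerEnv()
    etcd = etcd_db.get_global_etcd(trainer_env.endpoints, trainer_env.job_id)
    state = edl_state.load_from_etcd(
        etcd=etcd, state_name="_edl_state_0", timeout=60)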
+from __future__ import print_function + import os import sys diff --git a/python/edl/utils/reader.py b/python/edl/utils/reader.py index 6624af0a..69abc768 100644 --- a/python/edl/utils/reader.py +++ b/python/edl/utils/reader.py @@ -66,7 +66,7 @@ def load_from_etcd(self, etcd, reader_name, pod_id, timeout=60): return meta -def check_dist_readers(etcd): +def check_readers(etcd): servers = etcd.get_service(constants.ETCD_READER) if len(servers) <= 0: diff --git a/python/edl/utils/state.py b/python/edl/utils/state.py index 4b0ba3b1..1e151705 100644 --- a/python/edl/utils/state.py +++ b/python/edl/utils/state.py @@ -20,6 +20,9 @@ from edl.utils import string_utils from edl.utils import train_status as edl_train_status from edl.utils import unique_name +from edl.utils import env as edl_env +from edl.utils import etcd_client +from edl.utils class DataCheckpoint(json_serializable.Serializable): @@ -112,6 +115,9 @@ def update_current_epoch_attr(self, epoch_attr): class State(json_serializable.Serializable): def __init__(self, total_batch_size, user_defined=None): + #unique + self._name = unique_name.generator("_edl_state_") + # interface self._default = { "total_batch_size": total_batch_size, # user inputs @@ -120,11 +126,28 @@ def __init__(self, total_batch_size, user_defined=None): self._adjust_func = [] # internal - self._name = unique_name.generator("_edl_state_") self._model_path = None self._data_checkpoint = DataCheckpoint() self._train_status = TrainStatus() + self._restore_from_etcd() + + def _restore_from_etcd(self): + train_env = edl_env.TrainerEnv() + etcd = etcd_client.EtcdClient(train_env.etcd_endpoints, root=train_env.job_id) + etcd.init() + + state = load_from_etcd(etcd=etcd, + state_name=self._state.name, + user_defined=user_defined, timeout=60) + + self._default = state._default + self._user_defined = state._user_defined + self._adjust_func = state._adjust_func + self._model_apth = state._model_path + self._data_checkpoint = state._data_checkpoint + self._train_status = state._train_status + def from_json(self, json_str): d = json.loads(json_str) From 534e4cd67b7aac253d2bf5861a8fb087a0a6b59d Mon Sep 17 00:00:00 2001 From: gongweibao Date: Wed, 23 Sep 2020 18:23:20 +0800 Subject: [PATCH 4/9] add --- python/edl/__init__.py | 4 + python/edl/collective/distribute_reader.py | 258 +-------------------- python/edl/collective/state.py | 234 +++++++++++++++++++ python/edl/tests/unittests/test_state.py | 2 +- python/edl/tests/unittests/test_train.py | 36 ++- python/edl/utils/batch_data_accesser.py | 188 +++++++++++++++ python/edl/utils/batch_data_generator.py | 90 +++++++ python/edl/utils/log_utils.py | 7 +- python/edl/utils/state.py | 233 ------------------- 9 files changed, 544 insertions(+), 508 deletions(-) create mode 100644 python/edl/utils/batch_data_accesser.py create mode 100644 python/edl/utils/batch_data_generator.py delete mode 100644 python/edl/utils/state.py diff --git a/python/edl/__init__.py b/python/edl/__init__.py index abf198b9..959839ef 100644 --- a/python/edl/__init__.py +++ b/python/edl/__init__.py @@ -11,3 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
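[Editor's note] The __init__.py hunk below re-exports the collective classes at the package root, so user code can reach them as edl.State, edl.PaddleState and edl.DistributeReader. A hedged usage sketch of that public surface (argument values are illustrative; TxtFileSplitter is the splitter the old tests imported from edl.collective.dataset):

    import edl
    from edl.collective.dataset import TxtFileSplitter

    state = edl.PaddleState(total_batch_size=64)
    reader = edl.DistributeReader(
        state=state,
        file_list=["a.txt", "b.txt"],
        file_splitter_cls=TxtFileSplitter,
        batch_size=1)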
+ +from edl.collective.state import PaddleState +from edl.collective.state import State +from edl.collective.distribute_reader import Reader as DistributeReader diff --git a/python/edl/collective/distribute_reader.py b/python/edl/collective/distribute_reader.py index e6b068a2..2541d7b8 100644 --- a/python/edl/collective/distribute_reader.py +++ b/python/edl/collective/distribute_reader.py @@ -15,257 +15,19 @@ import multiprocessing import sys -import threading from edl.uitls import reader as edl_reader -from edl.utils import env as edl_env -from edl.utils import state as edl_state - -from edl.utils import data_server -from edl.utils import data_server_pb2 -from edl.utils import edl_process -from edl.utils import data_server_client -from edl.utils.error_utils import handle_errors_until_timeout -from edl.utils import etcd_db -from edl.utils.log_utils import logger +from edl.utils import exceptions from edl.utils import unique_name - - -class DataGenerator(edl_process.ProcessWrapper): - """ - 1. get file_list from data_server_leader - 2. parse files of file_list and put BatchData to out_quque - if reach data end, put None to out_queue. - 3. program will exit if meets any error - """ - - def __init__(self, reader_leader_endpoint, reader_name, pod_id, - all_files_list, splitter_cls, out_queue): - super(DataGenerator, self).__init__() - - self._batch_data_id = 0 - - self._leader_endpoint = reader_leader_endpoint - self._pod_id = pod_id - self._reader_name = reader_name - - self._file_list = all_files_list - self._splitter_cls = splitter_cls - self._data_queue = out_queue - - def _get_file_list(self, timeout=60): - client = data_server_client.DataServerClient() - return client.get_file_list( - leader_endpoint=self._leader_endpoint, - reader_name=self._reader_name, - pod_id=self._pod_id, - file_list=self._file_list) - - def _generate_batch_data(self): - self._batch_data_id += 1 - b = data_server_pb2.BatchData() - b.batch_data_id = self._batch_data_id - b.data = None - - return b - - def _read_batch_data(self): - b = self._generate_batch_data() - for m in self._get_file_list(): - if self._stop.set(): - break - - assert self._file_list[m.idx] == m.path - for record in self._splitter_cls(m.path): - fields = record - - assert fields[0] == m.idx - rec = data_server_pb2.Record() - rec.record_no = fields[0] - for field in fields[1:]: - rec.field_data.append(field) - - if len(b.records) >= self._batch_size: - self._data_queue.put(b) - b = self._generate_batch_data() - - if len(b.records) > 0: - self._data_queue.put(b) - - self._data_queue.put(None) - - def _worker_func(self): - try: - self._read_batch_data() - except Exception as e: - print(e, file=sys.stderr) - sys.exit(1) - - -class DataAccesser(object): - def __init__(self, reader_leader_endpoint, reader_name, trainer_env, - input_queue, out_queue, queue_size): - self._reader_leader_endpoint = reader_leader_endpoint - - self._reader_name = reader_name - self._trainer_env = trainer_env - self._etcd = etcd_db.get_global_etcd( - self._trainer_env.etcd_endpoint, job_id=self._trainer_env.job_id) - - # BatchData - self._input_queue = input_queue - self._out_queue = out_queue - # batch_data_id => BatchData - self._cache = {} - - # pb.BatchDataRequest queue - self._req_queue = threading.Queue(queue_size) - - self._data_server = data_server.DataServer(self) - self._data_server.start() - edl_reader.save_to_etcd( - self._etcd, - reader_name=self._reader_name, - pod_id=self._trainer_env.pod_id, - data_server_endpoint=self._data_server.endpoint) - - self._stop = 
threading.Event() - self._t_reporter = threading.Thread(target=self.report) - self._t_generater = threading.Thread(target=self.generate) - self._t_accesser = threading.Thread(target=self.access) - - self._client = data_server_client.DataServerClient() - - def start(self): - self._client.connect(self._reader_leader_endpoint) - self._t_reporter.start() - self._t_generater.start() - self._t_accesser.start() - - def _report(self, report_size=10): - """ - 1. Report BatchData index to Leader - 2. Get the BatchData index need to be processed - if there is no data, set None to req_queue - """ - batch_data_ids = [] - while not self._stop.set(): - while len(a) < report_size: - b = self._input_queue.pop() - if b is None: - logger.info("data read to end!") - break - batch_data_ids.append(b.batch_data_id) - with self._lock: - self._cache[b.batch_data_id] = b - - self._client.report_batch_data_meta( - reader_leader_endpoint=self._reader_leader_endpoint, - reader_name=self._name, - pod_id=self._trainer_env.pod_id, - dataserver_endpoint=self._data_server.endpoint, - batch_data_ids=batch_data_ids) - - batch_data_ids = [] - - while not self._stop.set() and len(batch_data_ids) > 0: - self._client.report_batch_data_meta( - reader_leader_endpoint=self._reader_leader_endpoint, - reader_name=self._name, - pod_id=self._trainer_env.pod_id, - dataserver_endpoint=self._data_server.endpoint, - batch_data_ids=batch_data_ids) - - self._client.reach_data_end( - reader_leader_endpoint=self._reader_leader_endpoint, - reader_name=self._name, - pod_id=self._trainer_env.pod_id) - - def _access(self): - while not self._stop.set(): - res = self._client.get_balanced_batch_data( - reader_leader_endpoint=self._reader_leader_endpoint, - reader_name=self._name, - pod_id=self._trainer_env.pod_id) - - self._req_queue.put(res) - - # data end - if res is None: - break - - def _get_batch_data(self, req): - """ - Read BatchData from local or remote by BatchDataRequest - """ - if self._trainer_env.pod_id != req.producer_pod_id: - return (req, self._client.get_batch_data(req)) - - return (req, self.get_local_batch_data(req)) - - def get_local_batch_data(self, req): - ret = [] - for batch_data_id in req.data.batch_data_ids: - with self._lock: - ret.append(self._cache.pop(batch_data_id)) - - return ret - - def _generate(self): - while not self._stop.set(): - req = self._req_queue.pop() - if req is None: - break - - ret = self._get_batch_data(req) - for b in ret: - self._out_queue.put(b) - - self._out_queue.put(None) - - def report(self): - try: - self._report() - except Exception as e: - print(e, file=sys.stderr) - sys.exit(1) - - def access(self): - try: - self._access() - except Exception as e: - print(e, file=sys.stderr) - sys.exit(1) - - def generate(self): - try: - self._generate() - except Exception as e: - print(e, file=sys.stderr) - sys.exit(1) - - -def access_batch_data(reader_leader, reader_name, trainer_env, input_queue, - out_queue, cache_capcity, error_queue): - """ - Run DataAccesser in a seperated process - """ - try: - a = DataAccesser(reader_leader, reader_name, trainer_env, input_queue, - out_queue, cache_capcity) - a.start() - except KeyboardInterrupt: - pass - except Exception as e: - import traceback - error_queue.put(traceback.format_exc()) - sys.exit(1) +from edl.utils.log_utils import logger class Reader(object): def __init__(self, + state, file_list, file_splitter_cls, batch_size, + #fields, cache_capcity=100): self._file_list = file_list assert isinstance(self._file_list, list), "file_list must be a list" @@ -274,20 
+36,10 @@ def __init__(self, self._cls = file_splitter_cls self._batch_size = batch_size + #self._fields = fields assert self._batch_size > 0, "batch size must > 0" self._cache_capcity = cache_capcity - # connections to data servers - self._trainer_env = edl_env.TrainerEnv() - - self._etcd = etcd_db.get_global_etcd(self._trainer_env.endpoints, - self._trainer_env.job_id) - - self._state = edl_state.load_from_etcd( - etcd=self._etcd, - state_name=self._name, - timeout=60) - # reader meta self._reader_leader = edl_reader.load_from_ectd( self._etcd, self._trainer_env.pod_leader_id, timeout=60) diff --git a/python/edl/collective/state.py b/python/edl/collective/state.py index e69de29b..9c84db92 100644 --- a/python/edl/collective/state.py +++ b/python/edl/collective/state.py @@ -0,0 +1,234 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import six +from edl.utils import constants +from edl.utils import error_utils +from edl.utils import exceptions +from edl.utils import json_serializable +from edl.utils import string_utils +from edl.utils import train_status as edl_train_status +from edl.utils import unique_name +from edl.utils import env as edl_env +from edl.utils import etcd_client +from edl.utils + + +class DataCheckpoint(json_serializable.Serializable): + def __init__(self, reader_name=None, file_list=None, processed_data=None): + self.reader_name = reader_name + self.file_list = file_list + #dict, file_idx_in_file_list=>[(record_idx_begin, record_idx_end), ...] 
+        self.processed_data = processed_data
+
+
+class EpochAttr(json_serializable.Serializable):
+    def __init__(self):
+        self.epoch_no = None
+        self.world_size = None
+        self.step_num = None
+        self.avg_step_time = None
+        self.step_no_of_epoch = None
+
+
+def _load_dict_of_cls_from_json(json_str, cls):
+    d = json.loads(json_str)
+
+    ret = {}
+    for k, v in six.iteritems(d):
+        ret[int(k)] = cls().from_json(v)
+
+    return ret
+
+
+def _dump_dict_to_json(d):
+    ret = {}
+    for k, v in six.iteritems(d):
+        ret[int(k)] = v.to_json()
+
+    return json.dumps(ret)
+
+
+class TrainStatus(json_serializable.Serializable):
+    def __init__(self):
+        self._epoch_no = None  # current
+        self.global_step_no = None  # current
+
+        self._epochs = {}  # epoch_no => EpochAttr
+        self.status = edl_train_status.TrainStatus.INITIAL
+
+    def to_json(self):
+        d = {
+            "_epoch_no": self._epoch_no,
+            "global_step_no": int(self.global_step_no),
+            "_epochs": _dump_dict_to_json(self._epochs),
+            "status": int(self.status),
+        }
+
+        return json.dumps(d)
+
+    def from_json(self, json_str):
+        d = json.loads(json_str)
+        self._epoch_no = d["_epoch_no"]
+        self.global_step_no = d["global_step_no"]
+        self.status = d["status"]
+
+        self._epochs = _load_dict_of_cls_from_json(d["_epochs"], EpochAttr)
+
+    @property
+    def epoch_no(self):
+        return self._epoch_no
+
+    @epoch_no.setter
+    def epoch_no(self, epoch_no):
+        assert epoch_no >= 0
+        if epoch_no not in self._epochs:
+            self._epochs[epoch_no] = EpochAttr()
+        self._epoch_no = epoch_no
+
+    def get_epoch_attr(self, epoch_no):
+        if epoch_no not in self._epochs:
+            return None
+        return self._epochs[epoch_no]
+
+    def update_epoch_attr(self, epoch_no, epoch_attr):
+        self._epochs[epoch_no] = epoch_attr
+
+    def get_current_epoch_attr(self):
+        return self.get_epoch_attr(self._epoch_no)
+
+    def update_current_epoch_attr(self, epoch_attr):
+        return self.update_epoch_attr(self._epoch_no, epoch_attr)
+
+
+class State(json_serializable.Serializable):
+    def __init__(self, total_batch_size, user_defined=None, restore=True):
+        # unique name of this state object in etcd
+        self._name = unique_name.generator("_edl_state_")
+
+        # interface
+        self._default = {
+            "total_batch_size": total_batch_size,  # user inputs
+        }
+        self._user_defined = user_defined
+        self._adjust_func = []
+
+        # internal
+        self._model_path = None
+        self._data_checkpoint = DataCheckpoint()
+        self._train_status = TrainStatus()
+
+        # restore=False is used by load_from_etcd to build an empty State
+        # without recursing back into etcd.
+        if restore:
+            self._restore_from_etcd()
+
+    def _restore_from_etcd(self):
+        train_env = edl_env.TrainerEnv()
+        etcd = etcd_client.EtcdClient(
+            train_env.etcd_endpoints, root=train_env.job_id)
+        etcd.init()
+
+        state = load_from_etcd(
+            etcd=etcd,
+            state_name=self._name,
+            user_defined=self._user_defined,
+            timeout=60)
+
+        self._default = state._default
+        self._user_defined = state._user_defined
+        self._adjust_func = state._adjust_func
+        self._model_path = state._model_path
+        self._data_checkpoint = state._data_checkpoint
+        self._train_status = state._train_status
+
+    def from_json(self, json_str):
+        d = json.loads(json_str)
+
+        self._default = d["_default"]
+        if self._user_defined is not None and d["_user_defined"] is not None:
+            self._user_defined.from_json(d["_user_defined"])
+
+        self._name = d["_name"]
+        self._model_path = d["_model_path"]
+        self._data_checkpoint.from_json(d["_data_checkpoint"])
+        self._train_status.from_json(d["_train_status"])
+        return d
+
+    def register_adjust_function(self, f):
+        self._adjust_func.append(f)
+
+    @property
+    def name(self):
+        return self._name
+
+    @property
+    def epoch_no(self):
+        return self._train_status.epoch_no
+
+    @property
+    def step_no_of_epoch(self):
+        return self._train_status.get_current_epoch_attr().step_no_of_epoch
+
+    @property
+    def global_step_no(self):
+        return self._train_status.global_step_no
+
+    @property
+    def total_batch_size(self):
+        return self._default["total_batch_size"]
+
+    @total_batch_size.setter
+    def total_batch_size(self, size):
+        self._default["total_batch_size"] = size
+
+
+@error_utils.handle_errors_until_timeout
+def load_from_etcd(etcd, state_name, user_defined=None, timeout=60):
+    value = etcd.get_value(constants.ETCD_STATE, state_name)
+
+    if value is None:
+        raise exceptions.EdlTableError("key:value = {}:{}".format(
+            etcd.get_full_path(constants.ETCD_STATE, state_name), value))
+
+    state = State(
+        total_batch_size=None, user_defined=user_defined, restore=False)
+    state.from_json(string_utils.bytes_to_string(value))
+    return state
+
+
+@error_utils.handle_errors_until_timeout
+def save_to_etcd(etcd, pod_id, state, timeout=60):
+    leader_key = etcd.get_full_path(constants.ETCD_POD_RANK,
+                                    constants.ETCD_POD_LEADER)
+    state_key = etcd.get_full_path(constants.ETCD_STATE, state.name)
+
+    etcd = etcd._etcd
+    status, _ = etcd.transaction(
+        compare=[etcd.transactions.value(leader_key) == pod_id, ],
+        success=[etcd.transactions.put(state_key, state.to_json()), ],
+        failure=[])
+
+    message = "pod_id:{} save_data_checkpoint status:{}".format(pod_id, status)
+    if not status:
+        raise exceptions.EdlEtcdIOError(message)
+
+
+class PaddleState(State):
+    def __init__(self,
+                 total_batch_size,
+                 user_defined=None,
+                 optimizer=None,
+                 exe=None,
+                 program=None):
+        super(PaddleState, self).__init__(
+            total_batch_size=total_batch_size, user_defined=user_defined)
+        self._exe = exe
+        self._program = program
+        self._optimizer = optimizer
diff --git a/python/edl/tests/unittests/test_state.py b/python/edl/tests/unittests/test_state.py
index ad84a76d..47b685de 100644
--- a/python/edl/tests/unittests/test_state.py
+++ b/python/edl/tests/unittests/test_state.py
@@ -17,7 +17,7 @@
 from edl.collective import serializable
 from edl.tests.unittests import etcd_test_base
 from edl.utils import constants
-from edl.utils import state as edl_state
+from edl.collective import state as edl_state
 
 
 class UserDefined(serializable.SerializableBase):
diff --git a/python/edl/tests/unittests/test_train.py b/python/edl/tests/unittests/test_train.py
index add0ab79..b5a4d0a4 100644
--- a/python/edl/tests/unittests/test_train.py
+++ b/python/edl/tests/unittests/test_train.py
@@ -13,22 +13,14 @@
 # limitations under the License.
 
 import unittest
-from edl.collective.data_reader import DistributedDataReader, FileMeta
-from edl.collective.dataset import TxtFileSplitter
-from paddle.fluid.incubate.fleet.collective import fleet
-
-learning_rate = 1.0
-start_program = None
-main_program = None
-exe = None
-
+import edl
+from edl.tests.unittests import etcd_test_base
+from edl.utils.log_utils import logger
 
 def adjust():
     learing_rate = learning_rate * edl.size()
 
-
-class TestDataReader(unittest.TestCase):
-    def setUp(self):
+class TestDataReader(etcd_test_base.EtcdTestBase):
+    def _read_data(self):
         self._file_list = ["./data_server/a.txt", "./data_server/b.txt"]
         self._data = {}
         for idx, p in enumerate(self._file_list):
@@ -42,26 +34,30 @@ def setUp(self):
                     record)  #[(path),(rec_no, splitted_fields)]...
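[Editor's note] Two remarks on collective/state.py above. First, save_to_etcd wraps the write in an etcd transaction that compares the leader key against the caller's pod_id, so only the current leader pod can persist state; everyone else gets EdlEtcdIOError. Second, load_from_etcd builds its State with restore=False so that deserialization does not recurse back into etcd. The checkpoint payload itself is plain JSON; a small round-trip sketch of DataCheckpoint, whose processed_data maps a file's index in file_list to the record ranges already consumed (all values made up):

    # Hedged sketch; Serializable is assumed to provide to_json/from_json.
    ckpt = DataCheckpoint(
        reader_name="_dist_reader_0",
        file_list=["a.txt", "b.txt"],
        processed_data={0: [(0, 99), (100, 199)], 1: [(0, 49)]})
    json_str = ckpt.to_json()
    restored = DataCheckpoint().from_json(json_str)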
def _train(self, state): - print("learning_rate:", learning_rate) - reader = DistributedDataReader( + learning_rate = 1.0 + start_program = None + main_program = None + exe = None + + reader = edl.DistributeReader( + state=state, file_list=self._file_list, file_splitter_cls=TxtFileSplitter, - splitted_data_field=["line"], - batch_size=1, - trainer_rank=0) + batch_size=1) for epoch in range(state.epoch, 5): for meta, batch in reader: + print("epoch_no:",epoch) edl.notify_end_one_batch(meta, state) edl.notify_end_one_epoch(state) def test_data_reader(self): - fleet.init() state = edl.PaddleState( - exe, start_program, main_program, optimizer, batch=0, epoch=0) + exe, start_program, main_program, optimizer, max_epoch_num=5) state.register_adjust_function([adjust]) - train(state) + self._train(state) if __name__ == '__main__': + unittest.main() diff --git a/python/edl/utils/batch_data_accesser.py b/python/edl/utils/batch_data_accesser.py new file mode 100644 index 00000000..4a0e70a1 --- /dev/null +++ b/python/edl/utils/batch_data_accesser.py @@ -0,0 +1,188 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import print_function + +import sys +import threading +from edl.uitls import reader as edl_reader +from edl.utils import data_server +from edl.utils import data_server_client +from edl.utils import etcd_db +from edl.utils.log_utils import logger + + +class Accesser(object): + """ + 1. get data from batch_data_generator + 2. report batch_data_meta to data_server_leader + 3. get batch_data_meta from data_server_leader + 4. get batch_data by batch_data_meta + """ + def __init__(self, reader_leader_endpoint, reader_name, trainer_env, + input_queue, out_queue, queue_size): + self._reader_leader_endpoint = reader_leader_endpoint + + self._reader_name = reader_name + self._trainer_env = trainer_env + self._etcd = etcd_db.get_global_etcd( + self._trainer_env.etcd_endpoint, job_id=self._trainer_env.job_id) + + # BatchData + self._input_queue = input_queue + self._out_queue = out_queue + # batch_data_id => BatchData + self._cache = {} + + # pb.BatchDataRequest queue + self._req_queue = threading.Queue(queue_size) + + self._data_server = data_server.DataServer(self) + self._data_server.start() + edl_reader.save_to_etcd( + self._etcd, + reader_name=self._reader_name, + pod_id=self._trainer_env.pod_id, + data_server_endpoint=self._data_server.endpoint) + + self._stop = threading.Event() + self._t_reporter = threading.Thread(target=self.report) + self._t_generater = threading.Thread(target=self.generate) + self._t_accesser = threading.Thread(target=self.access) + + self._client = data_server_client.DataServerClient() + + def start(self): + self._client.connect(self._reader_leader_endpoint) + self._t_reporter.start() + self._t_generater.start() + self._t_accesser.start() + + def _report(self, report_size=10): + """ + 1. Report BatchData index to Leader + 2. 
Get the BatchData index need to be processed + if there is no data, set None to req_queue + """ + batch_data_ids = [] + while not self._stop.set(): + while len(a) < report_size: + b = self._input_queue.pop() + if b is None: + logger.info("data read to end!") + break + batch_data_ids.append(b.batch_data_id) + with self._lock: + self._cache[b.batch_data_id] = b + + self._client.report_batch_data_meta( + reader_leader_endpoint=self._reader_leader_endpoint, + reader_name=self._name, + pod_id=self._trainer_env.pod_id, + dataserver_endpoint=self._data_server.endpoint, + batch_data_ids=batch_data_ids) + + batch_data_ids = [] + + while not self._stop.set() and len(batch_data_ids) > 0: + self._client.report_batch_data_meta( + reader_leader_endpoint=self._reader_leader_endpoint, + reader_name=self._name, + pod_id=self._trainer_env.pod_id, + dataserver_endpoint=self._data_server.endpoint, + batch_data_ids=batch_data_ids) + + self._client.reach_data_end( + reader_leader_endpoint=self._reader_leader_endpoint, + reader_name=self._name, + pod_id=self._trainer_env.pod_id) + + def _access(self): + while not self._stop.set(): + res = self._client.get_balanced_batch_data( + reader_leader_endpoint=self._reader_leader_endpoint, + reader_name=self._name, + pod_id=self._trainer_env.pod_id) + + self._req_queue.put(res) + + # data end + if res is None: + break + + def _get_batch_data(self, req): + """ + Read BatchData from local or remote by BatchDataRequest + """ + if self._trainer_env.pod_id != req.producer_pod_id: + return (req, self._client.get_batch_data(req)) + + return (req, self.get_local_batch_data(req)) + + def get_local_batch_data(self, req): + ret = [] + for batch_data_id in req.data.batch_data_ids: + with self._lock: + ret.append(self._cache.pop(batch_data_id)) + + return ret + + def _generate(self): + while not self._stop.set(): + req = self._req_queue.pop() + if req is None: + break + + ret = self._get_batch_data(req) + for b in ret: + self._out_queue.put(b) + + self._out_queue.put(None) + + def report(self): + try: + self._report() + except Exception as e: + print(e, file=sys.stderr) + sys.exit(1) + + def access(self): + try: + self._access() + except Exception as e: + print(e, file=sys.stderr) + sys.exit(1) + + def generate(self): + try: + self._generate() + except Exception as e: + print(e, file=sys.stderr) + sys.exit(1) + + +def access(reader_leader, reader_name, trainer_env, input_queue, + out_queue, cache_capcity, error_queue): + """ + Run DataAccesser in a seperated process + """ + try: + a = DataAccesser(reader_leader, reader_name, trainer_env, input_queue, + out_queue, cache_capcity) + a.start() + except KeyboardInterrupt: + pass + except Exception as e: + import traceback + error_queue.put(traceback.format_exc()) + sys.exit(1) \ No newline at end of file diff --git a/python/edl/utils/batch_data_generator.py b/python/edl/utils/batch_data_generator.py new file mode 100644 index 00000000..4bf56e2a --- /dev/null +++ b/python/edl/utils/batch_data_generator.py @@ -0,0 +1,90 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import print_function + +import sys +from edl.utils import data_server_client +from edl.utils import data_server_pb2 +from edl.utils import edl_process + + +class Generator(edl_process.ProcessWrapper): + """ + 1. get file_list from data_server_leader + 2. parse files of file_list and put BatchData to out_quque + if reach data end, put None to out_queue. + 3. program will exit if meets any error + """ + + def __init__(self, reader_leader_endpoint, reader_name, pod_id, + all_files_list, splitter_cls, out_queue): + super(DataGenerator, self).__init__() + + self._batch_data_id = 0 + + self._leader_endpoint = reader_leader_endpoint + self._pod_id = pod_id + self._reader_name = reader_name + + self._file_list = all_files_list + self._splitter_cls = splitter_cls + self._data_queue = out_queue + + def _get_file_list(self, timeout=60): + client = data_server_client.DataServerClient() + return client.get_file_list( + leader_endpoint=self._leader_endpoint, + reader_name=self._reader_name, + pod_id=self._pod_id, + file_list=self._file_list) + + def _generate_batch_data(self): + self._batch_data_id += 1 + b = data_server_pb2.BatchData() + b.batch_data_id = self._batch_data_id + b.data = None + + return b + + def _read_batch_data(self): + b = self._generate_batch_data() + for m in self._get_file_list(): + if self._stop.set(): + break + + assert self._file_list[m.idx] == m.path + for record in self._splitter_cls(m.path): + fields = record + + assert fields[0] == m.idx + rec = data_server_pb2.Record() + rec.record_no = fields[0] + for field in fields[1:]: + rec.field_data.append(field) + + if len(b.records) >= self._batch_size: + self._data_queue.put(b) + b = self._generate_batch_data() + + if len(b.records) > 0: + self._data_queue.put(b) + + self._data_queue.put(None) + + def _worker_func(self): + try: + self._read_batch_data() + except Exception as e: + print(e, file=sys.stderr) + sys.exit(1) \ No newline at end of file diff --git a/python/edl/utils/log_utils.py b/python/edl/utils/log_utils.py index c3e43ec1..d29360dc 100644 --- a/python/edl/utils/log_utils.py +++ b/python/edl/utils/log_utils.py @@ -16,9 +16,14 @@ logger = logging.getLogger("root") logger.propagate = False - +g_logger_set=False def get_logger(log_level, name="root"): + global g_logger_set + if g_logger_set: + return logger + g_logger_set = True + logger = logging.getLogger(name) logger.setLevel(log_level) diff --git a/python/edl/utils/state.py b/python/edl/utils/state.py deleted file mode 100644 index 1e151705..00000000 --- a/python/edl/utils/state.py +++ /dev/null @@ -1,233 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
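[Editor's note] The log_utils.py hunk above makes get_logger idempotent through a module-level g_logger_set flag. A common alternative is to key off the logger's own handler list, which avoids the extra global; a minimal sketch, not the module's code:

    import logging

    def get_logger(log_level, name="root"):
        log = logging.getLogger(name)
        if log.handlers:  # already configured by an earlier call
            return log
        log.setLevel(log_level)
        log.addHandler(logging.StreamHandler())
        return log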
-import json -import six -from edl.utils import constants -from edl.utils import error_utils -from edl.utils import exceptions -from edl.utils import json_serializable -from edl.utils import string_utils -from edl.utils import train_status as edl_train_status -from edl.utils import unique_name -from edl.utils import env as edl_env -from edl.utils import etcd_client -from edl.utils - - -class DataCheckpoint(json_serializable.Serializable): - def __init__(self, reader_name=None, file_list=None, processed_data=None): - self.reader_name = reader_name - self.file_list = file_list - #dict, file_idx_in_file_list=>[(record_idx_begin, record_idx_end), ...] - self.processed_data = processed_data - - -class EpochAttr(json_serializable.Serializable): - def __init__(self): - self.epoch_no = None - self.world_size = None - self.step_num = None - self.avg_step_time = None - self.step_no_of_epoch = None - - -def _load_dict_of_cls_from_json(json_str, cls): - d = json.loads(json_str) - - ret = {} - for k, v in six.iteritems(d): - ret[int(k)] = cls().from_json(v) - - return ret - - -def _dump_dict_to_json(d): - ret = {} - for k, v in six.iteritems(d): - ret[int(k)] = v.to_json() - - return json.dumps(ret) - - -class TrainStatus(json_serializable.Serializable): - def __init__(self): - self._epoch_no = None # current - self.global_step_no = None # current - - self._epochs = {} # epoch_no => EpochAttr - self.status = edl_train_status.TrainStatus.INITIAL - - def to_json(self): - d = { - "_epoch_no": self._epoch_no, - "global_step_no": int(self.global_step_no), - "_epochs": _dump_dict_to_json(self._epochs), - "status": int(self.status), - } - - return json.dumps(d) - - def from_json(self, json_str): - d = json.loads(json_str) - self._epoch_no = d["_epoch_no"] - self.global_step_no = d["global_step_no"] - self.status = d["status"] - - print("d[epochs]", d["_epochs"]) - self._epochs = _load_dict_of_cls_from_json(d["_epochs"], EpochAttr) - - @property - def epoch_no(self): - return self._epoch_no - - @epoch_no.setter - def epoch_no(self, epoch_no): - assert epoch_no >= 0 - if epoch_no not in self._epochs: - self._epochs[epoch_no] = {} - self._epoch_no = epoch_no - - def get_epoch_attr(self, epoch_no): - if epoch_no not in self._epochs: - return None - return self._epochs[epoch_no] - - def update_epoch_attr(self, epoch_no, epoch_attr): - self._epochs[epoch_no] = epoch_attr - - def get_current_epoch_attr(self): - return get_epoch_attr(self._epoch_no) - - def update_current_epoch_attr(self, epoch_attr): - return self._update_epoch_attr(self._epoch_no, epoch_attr) - - -class State(json_serializable.Serializable): - def __init__(self, total_batch_size, user_defined=None): - #unique - self._name = unique_name.generator("_edl_state_") - - # interface - self._default = { - "total_batch_size": total_batch_size, # user inputs - } - self._user_defined = user_defined - self._adjust_func = [] - - # internal - self._model_path = None - self._data_checkpoint = DataCheckpoint() - self._train_status = TrainStatus() - - self._restore_from_etcd() - - def _restore_from_etcd(self): - train_env = edl_env.TrainerEnv() - etcd = etcd_client.EtcdClient(train_env.etcd_endpoints, root=train_env.job_id) - etcd.init() - - state = load_from_etcd(etcd=etcd, - state_name=self._state.name, - user_defined=user_defined, timeout=60) - - self._default = state._default - self._user_defined = state._user_defined - self._adjust_func = state._adjust_func - self._model_apth = state._model_path - self._data_checkpoint = state._data_checkpoint - 
self._train_status = state._train_status - - def from_json(self, json_str): - d = json.loads(json_str) - - self._default = d["_default"] - if self._user_defined is not None and d["_user_defined"] is not None: - self._user_defined.from_json(d["_user_defined"]) - - self._name = d["_name"] - self._model_path = d["_model_path"] - self._data_checkpoint.from_json(d["_data_checkpoint"]) - self._train_status.from_json(d["_train_status"]) - return d - - def register_adjust_function(self, f): - self._adjust_func.append(f) - - @property - def name(self): - return self._name - - @property - def epoch_no(self): - return self._train_status.epoch_no - - @property - def step_no_of_epoch(self): - return self._train_status.get_current_epoch_attr().step_no_of_epoch - - @property - def global_step_no(self): - return self._train_status.global_step_no - - @property - def total_batch_size(self): - return self._defaults["total_batch_size"] - - @total_batch_size.setter - def total_batch_size(self, size): - self._defaults["total_batch_size"] = size - - -@error_utils.handle_errors_until_timeout -def load_from_etcd(etcd, state_name, user_defined=None, timeout=60): - value = etcd.get_value(constants.ETCD_STATE, state_name) - - if value is None: - raise exceptions.EdlTableError("key:value = {}:{}".format( - etcd.get_full_path(constants.ETCD_READER, state_name), value)) - - state = State(total_batch_size=None, user_defined=user_defined) - state.from_json(string_utils.bytes_to_string(value)) - return state - - -@error_utils.handle_errors_until_timeout -def save_to_etcd(etcd, pod_id, state, timeout=60): - leader_key = etcd.get_full_path(constants.ETCD_POD_RANK, - constants.ETCD_POD_LEADER) - state_key = etcd.get_full_path(constants.ETCD_STATE, state.name) - - etcd = etcd._etcd - status, _ = etcd.transaction( - compare=[etcd.transactions.value(leader_key) == pod_id, ], - success=[etcd.transactions.put(state_key, state.to_json()), ], - failure=[]) - - message = "pod_id:{} save_data_checkpoint status:{}".format(pod_id, status) - if not status: - raise exceptions.EdlEtcdIOError(message) - - -class PaddleState(State): - def __init__(self, - total_batch_size, - user_defined=None, - optimizer=None, - exe=None, - program=None): - super(PaddleState, self).__init__( - total_batch_size=total_batch_size, user_defined=user_defined) - self._exe = exe - self._program = program - self._optimizer = optimizer From ca8f5ff00d6f990e0fde2ecf3af8c0691ba84865 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Wed, 23 Sep 2020 20:51:37 +0800 Subject: [PATCH 5/9] add --- python/edl/collective/distribute_reader.py | 3 +- python/edl/utils/batch_data_accesser.py | 109 ++++++++------------- python/edl/utils/batch_data_generator.py | 109 ++++++++++++++++----- python/edl/utils/data_server.py | 2 +- python/edl/utils/log_utils.py | 5 +- python/edl/utils/process.py | 23 ++--- 6 files changed, 139 insertions(+), 112 deletions(-) diff --git a/python/edl/collective/distribute_reader.py b/python/edl/collective/distribute_reader.py index 2541d7b8..0029388a 100644 --- a/python/edl/collective/distribute_reader.py +++ b/python/edl/collective/distribute_reader.py @@ -19,6 +19,7 @@ from edl.utils import exceptions from edl.utils import unique_name from edl.utils.log_utils import logger +from edl.utils import batch_data_generator class Reader(object): @@ -79,7 +80,7 @@ def _check_accesser(self): "access process exit:{}".format(exitcode)) def __iter__(self): - self._generater = DataGenerator() + self._generater = 
multiprocessing.Process(target=batch_data_generator.generate, args=args) self._generator.start() self._accesser = multiprocessing.Process( diff --git a/python/edl/utils/batch_data_accesser.py b/python/edl/utils/batch_data_accesser.py index 4a0e70a1..32026fc8 100644 --- a/python/edl/utils/batch_data_accesser.py +++ b/python/edl/utils/batch_data_accesser.py @@ -35,8 +35,7 @@ def __init__(self, reader_leader_endpoint, reader_name, trainer_env, self._reader_name = reader_name self._trainer_env = trainer_env - self._etcd = etcd_db.get_global_etcd( - self._trainer_env.etcd_endpoint, job_id=self._trainer_env.job_id) + self._etcd = None # BatchData self._input_queue = input_queue @@ -47,65 +46,56 @@ def __init__(self, reader_leader_endpoint, reader_name, trainer_env, # pb.BatchDataRequest queue self._req_queue = threading.Queue(queue_size) - self._data_server = data_server.DataServer(self) - self._data_server.start() - edl_reader.save_to_etcd( - self._etcd, - reader_name=self._reader_name, - pod_id=self._trainer_env.pod_id, - data_server_endpoint=self._data_server.endpoint) + self._data_server = None self._stop = threading.Event() self._t_reporter = threading.Thread(target=self.report) self._t_generater = threading.Thread(target=self.generate) self._t_accesser = threading.Thread(target=self.access) - self._client = data_server_client.DataServerClient() + self._client = data_server_client.Client() def start(self): + try: + self._start() + finally: + self._stop.set() + self.__exit__() + + def __exit__(self): + if self._t_reporter is None: + self._t_reporter.join() + + if self._t_generater is None: + self._t_generater.join() + + if self._t_accesser is None: + self._t_accesser.join() + + self._t_reporter=None + self._t_accesser=None + self._t_generater=None + + def _start(self): + self._data_server = data_server.Server(self) + self._data_server.start() + + self._etcd = etcd_db.get_global_etcd( + self._trainer_env.etcd_endpoint, job_id=self._trainer_env.job_id) + + edl_reader.save_to_etcd( + self._etcd, + reader_name=self._reader_name, + pod_id=self._trainer_env.pod_id, + data_server_endpoint=self._data_server.endpoint, + timeout=30) + self._client.connect(self._reader_leader_endpoint) self._t_reporter.start() self._t_generater.start() self._t_accesser.start() - def _report(self, report_size=10): - """ - 1. Report BatchData index to Leader - 2. 
Get the BatchData index need to be processed - if there is no data, set None to req_queue - """ - batch_data_ids = [] - while not self._stop.set(): - while len(a) < report_size: - b = self._input_queue.pop() - if b is None: - logger.info("data read to end!") - break - batch_data_ids.append(b.batch_data_id) - with self._lock: - self._cache[b.batch_data_id] = b - - self._client.report_batch_data_meta( - reader_leader_endpoint=self._reader_leader_endpoint, - reader_name=self._name, - pod_id=self._trainer_env.pod_id, - dataserver_endpoint=self._data_server.endpoint, - batch_data_ids=batch_data_ids) - batch_data_ids = [] - - while not self._stop.set() and len(batch_data_ids) > 0: - self._client.report_batch_data_meta( - reader_leader_endpoint=self._reader_leader_endpoint, - reader_name=self._name, - pod_id=self._trainer_env.pod_id, - dataserver_endpoint=self._data_server.endpoint, - batch_data_ids=batch_data_ids) - - self._client.reach_data_end( - reader_leader_endpoint=self._reader_leader_endpoint, - reader_name=self._name, - pod_id=self._trainer_env.pod_id) def _access(self): while not self._stop.set(): @@ -149,27 +139,6 @@ def _generate(self): self._out_queue.put(None) - def report(self): - try: - self._report() - except Exception as e: - print(e, file=sys.stderr) - sys.exit(1) - - def access(self): - try: - self._access() - except Exception as e: - print(e, file=sys.stderr) - sys.exit(1) - - def generate(self): - try: - self._generate() - except Exception as e: - print(e, file=sys.stderr) - sys.exit(1) - def access(reader_leader, reader_name, trainer_env, input_queue, out_queue, cache_capcity, error_queue): @@ -177,12 +146,12 @@ def access(reader_leader, reader_name, trainer_env, input_queue, Run DataAccesser in a seperated process """ try: - a = DataAccesser(reader_leader, reader_name, trainer_env, input_queue, + a = Accesser(reader_leader, reader_name, trainer_env, input_queue, out_queue, cache_capcity) a.start() except KeyboardInterrupt: pass - except Exception as e: + except: import traceback error_queue.put(traceback.format_exc()) sys.exit(1) \ No newline at end of file diff --git a/python/edl/utils/batch_data_generator.py b/python/edl/utils/batch_data_generator.py index 4bf56e2a..f9b99920 100644 --- a/python/edl/utils/batch_data_generator.py +++ b/python/edl/utils/batch_data_generator.py @@ -12,12 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. from __future__ import print_function - -import sys +import os from edl.utils import data_server_client from edl.utils import data_server_pb2 from edl.utils import edl_process +from edl.utils import log_utils + +logger = None +class Args(object): + def __init__(self): + self.state = None + self.reader_leader_endpoint= None + self.reader_name= None + self.pod_id= None + self.all_files_list= None + self.splitter_cls= None + self.out_queue= None + self.error_queue= None class Generator(edl_process.ProcessWrapper): """ @@ -27,22 +39,23 @@ class Generator(edl_process.ProcessWrapper): 3. 
program will exit if meets any error """ - def __init__(self, reader_leader_endpoint, reader_name, pod_id, - all_files_list, splitter_cls, out_queue): - super(DataGenerator, self).__init__() - + def __init__(self, state, reader_leader_endpoint, reader_name, pod_id, + all_files_list, splitter_cls, out_queue, error_queue): + self._state = state self._batch_data_id = 0 self._leader_endpoint = reader_leader_endpoint - self._pod_id = pod_id self._reader_name = reader_name + self._pod_id = pod_id self._file_list = all_files_list self._splitter_cls = splitter_cls self._data_queue = out_queue + self._error_queue = error_queue + self._batch_data_ids = [] def _get_file_list(self, timeout=60): - client = data_server_client.DataServerClient() + client = data_server_client.Client() return client.get_file_list( leader_endpoint=self._leader_endpoint, reader_name=self._reader_name, @@ -53,38 +66,82 @@ def _generate_batch_data(self): self._batch_data_id += 1 b = data_server_pb2.BatchData() b.batch_data_id = self._batch_data_id - b.data = None + b.records = None return b + def _report(self, batch_data, report_size=10): + if batch_data is None: + if len(self._batch_data_ids) > 0: + self._client.report_batch_data_meta( + reader_leader_endpoint=self._reader_leader_endpoint, + reader_name=self._name, + pod_id=self._trainer_env.pod_id, + dataserver_endpoint=self._data_server.endpoint, + batch_data_ids=batch_data_ids) + return + + if len(self._batch_data_ids) <= report_size - 1: + self._batch_data_ids.append(batch_data.batch_data_id) + return + + self._client.report_batch_data_meta( + reader_leader_endpoint=self._reader_leader_endpoint, + reader_name=self._name, + pod_id=self._trainer_env.pod_id, + dataserver_endpoint=self._data_server.endpoint, + batch_data_ids=self._batch_data_ids) + self._batch_data_ids = [] + def _read_batch_data(self): - b = self._generate_batch_data() - for m in self._get_file_list(): + batch_data = self._generate_batch_data() + for ele in self._get_file_list(): if self._stop.set(): break - assert self._file_list[m.idx] == m.path - for record in self._splitter_cls(m.path): + assert self._file_list[ele.idx] == ele.path + logger.info("begin process file {}:{}".format(ele.idx, ele.path)) + + for record in self._splitter_cls(ele.path): fields = record - assert fields[0] == m.idx rec = data_server_pb2.Record() rec.record_no = fields[0] + assert isinstance(rec.record_no,int), \ + "first element of splitter_cls must be the record index of this file" + + #FIXME(gongwb) filter it for field in fields[1:]: rec.field_data.append(field) + batch_data.records.append(rec) - if len(b.records) >= self._batch_size: - self._data_queue.put(b) - b = self._generate_batch_data() + if len(batch_data.records) >= self._batch_size: + yield batch_data + batch_data = self._generate_batch_data() - if len(b.records) > 0: - self._data_queue.put(b) + if len(batch_data.records) > 0: + yield batch_data - self._data_queue.put(None) + def read_batch_data(self): + for batch_data in self._read_batch_data(): + self._report(batch_data) + self._data_queue.put(batch_data) - def _worker_func(self): - try: - self._read_batch_data() - except Exception as e: - print(e, file=sys.stderr) - sys.exit(1) \ No newline at end of file + self._report(None) + self._data_queue.put(None) + self._client.reach_data_end( + reader_leader_endpoint=self._reader_leader_endpoint, + reader_name=self._name, + pod_id=self._trainer_env.pod_id) + +def generate(args) + log_file_name = "edl_{}_{}.log".format(self._class.__name__, os.getpid().log) + global logger + 
logger = log_utils.get_logger(log_level=20, log_file_name=log_file_name) + + cls = Generator() + try: + cls.read_batch_data() + except: + import traceback + args.error_queue.put(traceback.format_exc()) \ No newline at end of file diff --git a/python/edl/utils/data_server.py b/python/edl/utils/data_server.py index 5a514a4d..a0672a28 100644 --- a/python/edl/utils/data_server.py +++ b/python/edl/utils/data_server.py @@ -350,7 +350,7 @@ def GetFileList(self, request, context): return res -class DataServer(object): +class Server(object): def __init__(self, trainer_env, reader_name, file_list, local_reader): self._server = None self._addr = None diff --git a/python/edl/utils/log_utils.py b/python/edl/utils/log_utils.py index d29360dc..12ab389e 100644 --- a/python/edl/utils/log_utils.py +++ b/python/edl/utils/log_utils.py @@ -18,12 +18,15 @@ logger.propagate = False g_logger_set=False -def get_logger(log_level, name="root"): +def get_logger(log_level, name="root", log_file_name=None): global g_logger_set if g_logger_set: return logger g_logger_set = True + if log_file_name: + logging.basicConfig(filename=log_file_name) + logger = logging.getLogger(name) logger.setLevel(log_level) diff --git a/python/edl/utils/process.py b/python/edl/utils/process.py index 58e3dbe2..e21feb17 100644 --- a/python/edl/utils/process.py +++ b/python/edl/utils/process.py @@ -16,36 +16,33 @@ import threading from edl.utils.log_utils import logger +from edl.utils import log_utils class ProcessWrapper(object): - def __init__(self): + def __init__(self, worker_func, args): self._stop = None - self._lock = None self._worker = None + self._lock = multiprocessing.Lock() self._stop = multiprocessing.Event() - self._lock = threading.Lock() - self._worker = multiprocessing.Process(target=self._worker_func) - - def _worker_func(self): - raise NotImplementedError + self._worker = multiprocessing.Process(target=worker_func, args=args) def start(self): + log_file_name = "edl_{}_{}.log".format(self._class.__name__, os.getpid().log) + log_utils.get_logger(log_level=20, log_file_name=log_file_name) self._worker.start() def stop(self): self._stop.set() - with self._lock: - if self._worker: - self._worker.join() - self._worker = None + if self._worker: + self._worker.join() + self._worker = None logger.info("{} exit".format(self.__class__.__name__)) def is_stopped(self): - with self._lock: - return self._worker == None + return self._worker is None or not self._worker.is_alive() def __exit__(self): self.stop() From 4564c621d1e56b91d050a86d8d585aba0c9f1d4b Mon Sep 17 00:00:00 2001 From: gongweibao Date: Wed, 23 Sep 2020 21:02:22 +0800 Subject: [PATCH 6/9] add --- python/edl/collective/distribute_reader.py | 3 +- python/edl/utils/batch_data_accesser.py | 31 +++++++------- python/edl/utils/batch_data_generator.py | 2 +- python/edl/utils/process.py | 48 ---------------------- 4 files changed, 19 insertions(+), 65 deletions(-) delete mode 100644 python/edl/utils/process.py diff --git a/python/edl/collective/distribute_reader.py b/python/edl/collective/distribute_reader.py index 0029388a..15e47862 100644 --- a/python/edl/collective/distribute_reader.py +++ b/python/edl/collective/distribute_reader.py @@ -14,12 +14,11 @@ from __future__ import print_function import multiprocessing -import sys from edl.uitls import reader as edl_reader +from edl.utils import batch_data_generator from edl.utils import exceptions from edl.utils import unique_name from edl.utils.log_utils import logger -from edl.utils import batch_data_generator class 
Reader(object): diff --git a/python/edl/utils/batch_data_accesser.py b/python/edl/utils/batch_data_accesser.py index 32026fc8..e74e6acc 100644 --- a/python/edl/utils/batch_data_accesser.py +++ b/python/edl/utils/batch_data_accesser.py @@ -19,8 +19,18 @@ from edl.utils import data_server from edl.utils import data_server_client from edl.utils import etcd_db -from edl.utils.log_utils import logger +from edl.utils import log_utils +logger = None + +class Args(object): + def __init__(self): + self.reader_leader_endpoint= None + self.reader_name= None + self.trainer_env= None + self.input_queue= None + self.out_queue= None + self.queue_size= None class Accesser(object): """ @@ -63,16 +73,12 @@ def start(self): self.__exit__() def __exit__(self): - if self._t_reporter is None: - self._t_reporter.join() - if self._t_generater is None: self._t_generater.join() if self._t_accesser is None: self._t_accesser.join() - self._t_reporter=None self._t_accesser=None self._t_generater=None @@ -140,18 +146,15 @@ def _generate(self): self._out_queue.put(None) -def access(reader_leader, reader_name, trainer_env, input_queue, - out_queue, cache_capcity, error_queue): - """ - Run DataAccesser in a seperated process - """ +def generate(args): + log_file_name = "edl_data_generator_{}.log".format(os.getpid()) + global logger + logger = log_utils.get_logger(log_level=20, log_file_name=log_file_name) + try: a = Accesser(reader_leader, reader_name, trainer_env, input_queue, out_queue, cache_capcity) a.start() - except KeyboardInterrupt: - pass except: import traceback - error_queue.put(traceback.format_exc()) - sys.exit(1) \ No newline at end of file + args.error_queue.put(traceback.format_exc()) \ No newline at end of file diff --git a/python/edl/utils/batch_data_generator.py b/python/edl/utils/batch_data_generator.py index f9b99920..01812ea0 100644 --- a/python/edl/utils/batch_data_generator.py +++ b/python/edl/utils/batch_data_generator.py @@ -135,7 +135,7 @@ def read_batch_data(self): pod_id=self._trainer_env.pod_id) def generate(args) - log_file_name = "edl_{}_{}.log".format(self._class.__name__, os.getpid().log) + log_file_name = "edl_data_generator_{}.log".format(os.getpid()) global logger logger = log_utils.get_logger(log_level=20, log_file_name=log_file_name) diff --git a/python/edl/utils/process.py b/python/edl/utils/process.py deleted file mode 100644 index e21feb17..00000000 --- a/python/edl/utils/process.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
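[Editor's note] With process.py deleted below, the series settles on a simpler worker convention: each subprocess runs a plain module-level function that takes a single Args object and reports failures through an error queue instead of raising across the process boundary. A self-contained sketch of that convention (the worker body is hypothetical):

    import multiprocessing
    import traceback

    class Args(object):
        def __init__(self):
            self.error_queue = multiprocessing.Queue()

    def worker_entry(args):
        try:
            pass  # real body: generate or access batch data
        except Exception:
            args.error_queue.put(traceback.format_exc())

    if __name__ == "__main__":
        args = Args()
        p = multiprocessing.Process(target=worker_entry, args=(args, ))
        p.start()
        p.join()

One unrelated reviewer nit while here: the various `while not self._stop.set():` loops in these files read as if they mean `is_set()`; Event.set() marks the event and returns None, so the current spelling both flips the flag and always evaluates truthy.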
-
-import multiprocessing
-import threading
-
-from edl.utils.log_utils import logger
-from edl.utils import log_utils
-
-
-class ProcessWrapper(object):
-    def __init__(self, worker_func, args):
-        self._stop = None
-        self._worker = None
-
-        self._lock = multiprocessing.Lock()
-        self._stop = multiprocessing.Event()
-        self._worker = multiprocessing.Process(target=worker_func, args=args)
-
-    def start(self):
-        log_file_name = "edl_{}_{}.log".format(self._class.__name__, os.getpid().log)
-        log_utils.get_logger(log_level=20, log_file_name=log_file_name)
-        self._worker.start()
-
-    def stop(self):
-        self._stop.set()
-        if self._worker:
-            self._worker.join()
-            self._worker = None
-
-        logger.info("{} exit".format(self.__class__.__name__))
-
-    def is_stopped(self):
-        return self._worker is None or not self._worker.is_alive()
-
-    def __exit__(self):
-        self.stop()

From 73fec297811db76661413b1b60ad4a7eff8c22a9 Mon Sep 17 00:00:00 2001
From: gongweibao
Date: Sun, 27 Sep 2020 16:07:23 +0800
Subject: [PATCH 7/9] add

---
 python/edl/collective/distribute_reader.py | 39 ++++++++++++++--------
 python/edl/tests/unittests/test_train.py   | 13 ++++----
 python/edl/utils/batch_data_accesser.py    | 27 ++++++++-------
 python/edl/utils/batch_data_generator.py   | 27 +++++++--------
 4 files changed, 59 insertions(+), 47 deletions(-)

diff --git a/python/edl/collective/distribute_reader.py b/python/edl/collective/distribute_reader.py
index 15e47862..f04ee1b7 100644
--- a/python/edl/collective/distribute_reader.py
+++ b/python/edl/collective/distribute_reader.py
@@ -16,11 +16,11 @@
 import multiprocessing
 from edl.uitls import reader as edl_reader
 from edl.utils import batch_data_generator
+from edl.utils import batch_data_accesser
+from six.moves import queue
 from edl.utils import exceptions
 from edl.utils import unique_name
 from edl.utils.log_utils import logger
-
 
 class Reader(object):
     def __init__(self,
                  state,
@@ -31,6 +31,7 @@ def __init__(self,
                  cache_capcity=100):
         self._file_list = file_list
         assert isinstance(self._file_list, list), "file_list must be a list"
+        self._state = state
 
         self._name = unique_name.generator("_dist_reader_")
 
@@ -52,7 +53,8 @@ def __init__(self,
 
     def stop(self):
         if self._generater:
-            self._generater.stop()
+            self._generater.terminate()
+            self._generater.join()
             self._generater = None
 
         if self._accesser:
@@ -63,12 +65,12 @@ def stop(self):
     def __exit__(self):
         self.stop()
 
-    def _check_accesser(self):
-        if self._accesser.is_alive():
+    def _check(self, proc):
+        if proc.is_alive():
             return True
 
-        self._accesser.join()
-        exitcode = self._accesser.exitcode
+        proc.join()
+        exitcode = proc.exitcode
         if exitcode == 0:
             return False
 
@@ -78,21 +80,30 @@ def _check_accesser(self):
         raise exceptions.EdlAccessDataError(
             "access process exit:{}".format(exitcode))
 
-    def __iter__(self):
-        self._generater = multiprocessing.Process(target=batch_data_generator.generate, args=args)
+    def _start_generator(self):
+        args = batch_data_generator.Args()
+        self._generator = multiprocessing.Process(
+            target=batch_data_generator.generate, args=(args, ))
         self._generator.start()
 
+    def _start_accesser(self):
+        args = batch_data_accesser.Args()
         self._accesser = multiprocessing.Process(
-            access_batch_data,
+            target=batch_data_accesser.generate,
-            args=(self._reader_leader, self._name, self._trainer_env,
-                  self._generater_out_queue, self._accesser_out_queue,
-                  self._cache_capcity))
+            args=(args, ))
+
+    def __iter__(self):
+        self._start_generator()
+        self._start_accesser()
+
         while True:
-            if not self._check_accesser():
+            if not self._check(self._accesser):
+                break
+
+            if not self._check(self._generator):
                 break
 
             try:
-                b = self._accesser_out_queue.pop(60)
+                b = self._accesser_out_queue.get(timeout=10)
-            except multiprocessing.Queue.Empty as e:
+            except queue.Empty:
                 continue
diff --git a/python/edl/tests/unittests/test_train.py b/python/edl/tests/unittests/test_train.py
index b5a4d0a4..b9a37dbe 100644
--- a/python/edl/tests/unittests/test_train.py
+++ b/python/edl/tests/unittests/test_train.py
@@ -34,11 +34,6 @@ def _read_data(self):
                     record)  #[(path),(rec_no, splitted_fields)]...
 
     def _train(self, state):
-        learning_rate = 1.0
-        start_program = None
-        main_program = None
-        exe = None
-
         reader = edl.DistributeReader(
             state=state,
             file_list=self._file_list,
@@ -52,8 +47,14 @@ def _train(self, state):
         edl.notify_end_one_epoch(state)
 
     def test_data_reader(self):
+        learning_rate = 1.0
+        start_program = None
+        main_program = None
+        exe = None
+        optimizer = None
+
         state = edl.PaddleState(
-            exe, start_program, main_program, optimizer, max_epoch_num=5)
+            exe, start_program, main_program, optimizer)
         state.register_adjust_function([adjust])
 
         self._train(state)
diff --git a/python/edl/utils/batch_data_accesser.py b/python/edl/utils/batch_data_accesser.py
index e74e6acc..ad70050a 100644
--- a/python/edl/utils/batch_data_accesser.py
+++ b/python/edl/utils/batch_data_accesser.py
@@ -14,12 +14,13 @@
 from __future__ import print_function
 
 import sys
+import os
 import threading
+
+from six.moves import queue
 from edl.uitls import reader as edl_reader
 from edl.utils import data_server
 from edl.utils import data_server_client
 from edl.utils import etcd_db
-from edl.utils import log_utils
 
 logger = None
 
@@ -39,22 +40,21 @@ class Accesser(object):
     3. get batch_data_meta from data_server_leader
     4. get batch_data by batch_data_meta
     """
-    def __init__(self, reader_leader_endpoint, reader_name, trainer_env,
-                 input_queue, out_queue, queue_size):
-        self._reader_leader_endpoint = reader_leader_endpoint
+    def __init__(self, args):
+        self._reader_leader_endpoint = args.reader_leader_endpoint
 
-        self._reader_name = reader_name
-        self._trainer_env = trainer_env
+        self._reader_name = args.reader_name
+        self._trainer_env = args.trainer_env
         self._etcd = None
 
         # BatchData
-        self._input_queue = input_queue
-        self._out_queue = out_queue
+        self._input_queue = args.input_queue
+        self._out_queue = args.out_queue
         # batch_data_id => BatchData
         self._cache = {}
 
         # pb.BatchDataRequest queue
-        self._req_queue = threading.Queue(queue_size)
+        self._req_queue = queue.Queue(args.queue_size)
 
         self._data_server = None
 
@@ -101,8 +101,6 @@ def _start(self):
         self._t_generater.start()
         self._t_accesser.start()
 
-
-
     def _access(self):
         while not self._stop.set():
             res = self._client.get_balanced_batch_data(
@@ -148,13 +146,14 @@ def _generate(self):
 
 def generate(args):
-    log_file_name = "edl_data_generator_{}.log".format(os.getpid())
+    log_file_name = "edl_data_accesser_{}.log".format(os.getpid())
+    from edl.utils import log_utils
     global logger
     logger = log_utils.get_logger(log_level=20, log_file_name=log_file_name)
+    logger.info("args:{}".format(args))
 
     try:
-        a = Accesser(reader_leader, reader_name, trainer_env, input_queue,
-                     out_queue, cache_capcity)
-        a.start()
+        accesser = Accesser(args)
+        accesser.start()
     except:
         import traceback
         args.error_queue.put(traceback.format_exc())
\ No newline at end of file
diff --git a/python/edl/utils/batch_data_generator.py b/python/edl/utils/batch_data_generator.py
index 01812ea0..f8ac7600 100644
--- a/python/edl/utils/batch_data_generator.py
+++ b/python/edl/utils/batch_data_generator.py
@@ -16,7 +16,7 @@
 from edl.utils import data_server_client
 from edl.utils import data_server_pb2
 from edl.utils import edl_process
-from
edl.utils import log_utils + logger = None @@ -39,19 +39,18 @@ class Generator(edl_process.ProcessWrapper): 3. program will exit if meets any error """ - def __init__(self, state, reader_leader_endpoint, reader_name, pod_id, - all_files_list, splitter_cls, out_queue, error_queue): - self._state = state + def __init__(self, args): + self._state = args.state self._batch_data_id = 0 - self._leader_endpoint = reader_leader_endpoint - self._reader_name = reader_name - self._pod_id = pod_id + self._leader_endpoint = args.reader_leader_endpoint + self._reader_name = args.reader_name + self._pod_id = args.pod_id - self._file_list = all_files_list - self._splitter_cls = splitter_cls - self._data_queue = out_queue - self._error_queue = error_queue + self._file_list = args.all_files_list + self._splitter_cls = args.splitter_cls + self._data_queue = args.out_queue + self._error_queue = args.error_queue self._batch_data_ids = [] def _get_file_list(self, timeout=60): @@ -136,12 +135,14 @@ def read_batch_data(self): def generate(args) log_file_name = "edl_data_generator_{}.log".format(os.getpid()) + from edl.utils import log_utils global logger logger = log_utils.get_logger(log_level=20, log_file_name=log_file_name) + logger.info("args:{}".format(args)) - cls = Generator() try: - cls.read_batch_data() + generator = Generator(args) + generator.read_batch_data() except: import traceback args.error_queue.put(traceback.format_exc()) \ No newline at end of file From 61e145ddfbe18bf1a236dde34b565a9d7b3e2b29 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Fri, 9 Oct 2020 04:13:59 +0000 Subject: [PATCH 8/9] add test=develop --- python/edl/collective/distribute_reader.py | 59 +++++++++++------ python/edl/utils/batch_data_accesser.py | 77 ++++++++++++---------- python/edl/utils/batch_data_generator.py | 56 +++++++++------- python/edl/utils/data_server_client.py | 2 +- python/edl/utils/exceptions.py | 2 +- python/edl/utils/reader.py | 18 ++--- 6 files changed, 124 insertions(+), 90 deletions(-) diff --git a/python/edl/collective/distribute_reader.py b/python/edl/collective/distribute_reader.py index b0b2aade..24e31028 100644 --- a/python/edl/collective/distribute_reader.py +++ b/python/edl/collective/distribute_reader.py @@ -24,23 +24,17 @@ class Reader(object): def __init__( - self, - state, - file_list, - file_splitter_cls, - batch_size, - # fields, - cache_capcity=100, + self, state, file_list, file_splitter_cls, batch_size, cache_capcity=100, ): self._file_list = file_list - assert isinstance(self._file_list, list), "file_list must be a list" + assert isinstance(self._file_list, list), "file_list must be a list of string" self._state = state self._name = unique_name.generator("_dist_reader_") self._cls = file_splitter_cls self._batch_size = batch_size - # self._fields = fields + assert self._batch_size > 0, "batch size must > 0" self._cache_capcity = cache_capcity @@ -57,16 +51,19 @@ def __init__( self._accesser_out_queue = multiprocessing.Queue(self._cache_capcity) self._accesser_error_queue = multiprocessing.Queue() - def stop(self): - if self._generater: - self._generater.terminate() - self._generater.join() - self._generater = None + self._logger_no = 0 - if self._accesser: - self._accesser.terminate() - self._accesser.join() - self._accesser = None + def _terminate_process(self, proc): + if proc is None: + return + + proc.terminate() + proc.join() + proc = None + + def stop(self): + self._terminate_process(self._generator) + self._terminate_process(self._accesser) def __exit__(self): self.stop() @@ -81,12 
+78,22 @@ def _check(self, proc, error_queue): return False if len(error_queue) > 0: - raise exceptions.EdlAccessDataError(error_queue[0]) + raise exceptions.EdlDataProcessError(error_queue[0]) else: - raise exceptions.EdlAccessDataError("process exit:{}".format(exitcode)) + raise exceptions.EdlDataProcessError("process exit:{}".format(exitcode)) def _start_generator(self): args = batch_data_generator.Args() + args.state = self._state + args.reader_leader_endpoint = self._reader_leader.endpoint + args.reader_name = self._reader_leader.name + args.pod_id = self._pod_id + args.all_file_list = self._file_list + args.splitter_cls = self._splitter_cls + args.out_queue = self._generater_out_queue + args.error_queue = self._generater_error_queue + args.loger_name = "{}_generator_{}.log".format(self._name, self._logger_no) + self._generator = multiprocessing.Process( target=batch_data_generator.generate, args=args ) @@ -94,13 +101,23 @@ def _start_generator(self): def _start_accesser(self): args = batch_data_accesser.Args() + args.reader_leader_endpoint = self._reader_leader.endpoint + args.reader_name = self._reader_leader.name + args.input_queue = self._generater_out_queue + args.trainer_env = self._trainer_env + args.out_queue = self._accesser_out_queue + args.queue_size = self._cache_capcity + args.loger_name = "{}_accesser_{}.log".format(self._name, self._logger_no) + self._accesser = multiprocessing.Process( batch_data_accesser.generate, args=(args) ) + self._accesser.start() def __iter__(self): self._start_generator() self._start_accesser() + self._logger_no += 1 while True: if not self._check(self._accesser, self._accesser_error_queue): @@ -115,6 +132,6 @@ def __iter__(self): continue if b is None: - logger.info("{} reach data end".format(self._name)) + logger.info("distributed reader {} reach data end".format(self._name)) break yield {"meta": b[0], "data": b[1]} diff --git a/python/edl/utils/batch_data_accesser.py b/python/edl/utils/batch_data_accesser.py index ad70050a..ed307ad3 100644 --- a/python/edl/utils/batch_data_accesser.py +++ b/python/edl/utils/batch_data_accesser.py @@ -13,8 +13,6 @@ # limitations under the License. from __future__ import print_function -import sys -import os import threading from edl.uitls import reader as edl_reader @@ -24,32 +22,35 @@ logger = None + class Args(object): def __init__(self): - self.reader_leader_endpoint= None - self.reader_name= None - self.trainer_env= None - self.input_queue= None - self.out_queue= None - self.queue_size= None + self.reader_leader_endpoint = None + self.reader_name = None + self.trainer_env = None + self.input_queue = None + self.out_queue = None + self.queue_size = None + self.error_queue = None + class Accesser(object): """ 1. get data from batch_data_generator - 2. report batch_data_meta to data_server_leader - 3. get batch_data_meta from data_server_leader - 4. get batch_data by batch_data_meta + 2. get batch_data_meta from data_server_leader + 3. 
get batch_data by batch_data_meta """ + def __init__(self, args): self._reader_leader_endpoint = args.reader_leader_endpoint - self._reader_name = argsreader_name - self._trainer_env = argstrainer_env - self._etcd = None + self._reader_name = args.reader_name + self._trainer_env = args.trainer_env + # self._etcd = None # BatchData - self._input_queue = argsinput_queue - self._out_queue = argsout_queue + self._input_queue = args.input_queue + self._out_queue = args.out_queue # batch_data_id => BatchData self._cache = {} @@ -59,9 +60,9 @@ def __init__(self, args): self._data_server = None self._stop = threading.Event() - self._t_reporter = threading.Thread(target=self.report) - self._t_generater = threading.Thread(target=self.generate) - self._t_accesser = threading.Thread(target=self.access) + # self._t_reporter = threading.Thread(target=self._report) + self._t_generater = threading.Thread(target=self._generate) + self._t_accesser = threading.Thread(target=self._access) self._client = data_server_client.Client() @@ -73,40 +74,47 @@ def start(self): self.__exit__() def __exit__(self): - if self._t_generater is None: + # if self._t_reporter is not None: + # self._t_reporter.join() + + if self._t_generater is not None: self._t_generater.join() - if self._t_accesser is None: + if self._t_accesser is not None: self._t_accesser.join() - self._t_accesser=None - self._t_generater=None + # self._t_reporter = None + self._t_accesser = None + self._t_generater = None def _start(self): self._data_server = data_server.Server(self) self._data_server.start() - self._etcd = etcd_db.get_global_etcd( - self._trainer_env.etcd_endpoint, job_id=self._trainer_env.job_id) + etcd = etcd_db.get_global_etcd( + self._trainer_env.etcd_endpoint, job_id=self._trainer_env.job_id + ) edl_reader.save_to_etcd( - self._etcd, + etcd, reader_name=self._reader_name, pod_id=self._trainer_env.pod_id, data_server_endpoint=self._data_server.endpoint, - timeout=30) + timeout=30, + ) self._client.connect(self._reader_leader_endpoint) - self._t_reporter.start() + # self._t_reporter.start() self._t_generater.start() self._t_accesser.start() def _access(self): while not self._stop.set(): - res = self._client.get_balanced_batch_data( + res = self._client.get_batch_data_meta( reader_leader_endpoint=self._reader_leader_endpoint, reader_name=self._name, - pod_id=self._trainer_env.pod_id) + pod_id=self._trainer_env.pod_id, + ) self._req_queue.put(res) @@ -145,15 +153,16 @@ def _generate(self): def generate(args): - log_file_name = "edl_data_generator_{}.log".format(os.getpid()) from edl.utils import log_utils + global logger - logger = log_utils.get_logger(log_level=20, log_file_name=log_file_name) + logger = log_utils.get_logger(log_level=20, log_file_name=args.loger_file_name) logger.info("args:{}".format(args)) try: accesser = Accesser(args) accesser.start() - except: + except Exception: import traceback - args.error_queue.put(traceback.format_exc()) \ No newline at end of file + + args.error_queue.put(traceback.format_exc()) diff --git a/python/edl/utils/batch_data_generator.py b/python/edl/utils/batch_data_generator.py index f8ac7600..75427ae9 100644 --- a/python/edl/utils/batch_data_generator.py +++ b/python/edl/utils/batch_data_generator.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
 from __future__ import print_function
-import os
 from edl.utils import data_server_client
 from edl.utils import data_server_pb2
 from edl.utils import edl_process
@@ -20,16 +19,18 @@
 
 logger = None
 
+
 class Args(object):
     def __init__(self):
         self.state = None
-        self.reader_leader_endpoint= None
-        self.reader_name= None
-        self.pod_id= None
-        self.all_files_list= None
-        self.splitter_cls= None
-        self.out_queue= None
-        self.error_queue= None
+        self.reader_leader_endpoint = None
+        self.reader_name = None
+        self.pod_id = None
+        self.all_files_list = None
+        self.splitter_cls = None
+        self.out_queue = None
+        self.error_queue = None
+        self.log_file_name = None
+
 
 class Generator(edl_process.ProcessWrapper):
     """
@@ -50,7 +51,6 @@ def __init__(self, args):
         self._file_list = args.all_files_list
         self._splitter_cls = args.splitter_cls
         self._data_queue = args.out_queue
-        self._error_queue = args.error_queue
         self._batch_data_ids = []
 
     def _get_file_list(self, timeout=60):
@@ -59,7 +59,9 @@ def _get_file_list(self, timeout=60):
             leader_endpoint=self._leader_endpoint,
             reader_name=self._reader_name,
             pod_id=self._pod_id,
-            file_list=self._file_list)
+            file_list=self._file_list,
+            timeout=timeout,
+        )
 
     def _generate_batch_data(self):
         self._batch_data_id += 1
@@ -77,7 +79,9 @@ def _report(self, batch_data, report_size=10):
                 reader_name=self._reader_name,
                 pod_id=self._pod_id,
                 dataserver_endpoint=self._data_server.endpoint,
-                batch_data_ids=batch_data_ids)
+                batch_data_ids=self._batch_data_ids,
+            )
+            self._batch_data_ids = []
             return
 
         if len(self._batch_data_ids) <= report_size - 1:
@@ -89,7 +93,8 @@ def _report(self, batch_data, report_size=10):
                 reader_name=self._reader_name,
                 pod_id=self._pod_id,
                 dataserver_endpoint=self._data_server.endpoint,
-                batch_data_ids=self._batch_data_ids)
+                batch_data_ids=self._batch_data_ids,
+            )
             self._batch_data_ids = []
 
     def _read_batch_data(self):
@@ -101,15 +106,14 @@ def _read_batch_data(self):
             assert self._file_list[ele.idx] == ele.path
             logger.info("begin process file {}:{}".format(ele.idx, ele.path))
 
-            for record in self._splitter_cls(ele.path):
-                fields = record
-
+            for fields in self._splitter_cls(ele.path):
                 rec = data_server_pb2.Record()
                 rec.record_no = fields[0]
-                assert isinstance(rec.record_no,int), \
-                    "first element of splitter_cls must be the record index of this file"
+                assert isinstance(
+                    rec.record_no, int
+                ), "first element of splitter_cls must be the record index of this file"
 
-                #FIXME(gongwb) filter it
+                # FIXME(gongwb) filter it
                 for field in fields[1:]:
                     rec.field_data.append(field)
                 batch_data.records.append(rec)
@@ -131,18 +135,22 @@ def read_batch_data(self):
         self._client.reach_data_end(
             reader_leader_endpoint=self._leader_endpoint,
             reader_name=self._reader_name,
-            pod_id=self._pod_id)
+            pod_id=self._pod_id,
+            timeout=60,
+        )
 
 
-def generate(args)
-    log_file_name = "edl_data_generator_{}.log".format(os.getpid())
+def generate(args):
     from edl.utils import log_utils
+
     global logger
-    logger = log_utils.get_logger(log_level=20, log_file_name=log_file_name)
+    logger = log_utils.get_logger(log_level=20, log_file_name=args.log_file_name)
     logger.info("args:{}".format(args))
 
     try:
         generator = Generator(args)
         generator.read_batch_data()
-    except:
+    except Exception:
         import traceback
-        args.error_queue.put(traceback.format_exc())
\ No newline at end of file
+
+        args.error_queue.put(traceback.format_exc())
diff --git a/python/edl/utils/data_server_client.py b/python/edl/utils/data_server_client.py
index a05a25af..e0021255 100644
--- a/python/edl/utils/data_server_client.py
+++ b/python/edl/utils/data_server_client.py
@@ -126,7 +126,7 @@ def get_batch_data_meta(
         exceptions.deserialize(res.status)
 
     logger.debug(
-        "pod client get_balanced_batch_data meta:{}".format(
+        "pod client get_batch_data_meta:{}".format(
             pb_utils.batch_data_meta_response_to_string(res)
         )
     )
diff --git a/python/edl/utils/exceptions.py b/python/edl/utils/exceptions.py
index 981c516e..a223e46a 100644
--- a/python/edl/utils/exceptions.py
+++ b/python/edl/utils/exceptions.py
@@ -81,7 +81,7 @@ class EdlFileListNotMatchError(EdlException):
     pass
 
 
-class EdlDataGenerateError(EdlException):
+class EdlDataProcessError(EdlException):
     pass
 
diff --git a/python/edl/utils/reader.py b/python/edl/utils/reader.py
index a4c3aa49..e0a7b70b 100644
--- a/python/edl/utils/reader.py
+++ b/python/edl/utils/reader.py
@@ -21,23 +21,23 @@
 class ReaderMeta(object):
     def __init__(self, name, pod_id, data_server_endpoint):
-        self._name = name
-        self._pod_id = pod_id
-        self._endpoint = data_server_endpoint
+        self.name = name
+        self.pod_id = pod_id
+        self.endpoint = data_server_endpoint
 
     def to_json(self):
         d = {
-            "name": self._name,
-            "pod_id": self._pod_id,
-            "endpoint": self._endpoint,
+            "name": self.name,
+            "pod_id": self.pod_id,
+            "endpoint": self.endpoint,
         }
         return json.dumps(d)
 
     def from_json(self, s):
         d = json.loads(s)
-        self._name = d["name"]
-        self._pod_id = d["pod_id"]
-        self._endpoint = d["endpoint"]
+        self.name = d["name"]
+        self.pod_id = d["pod_id"]
+        self.endpoint = d["endpoint"]
 
     def __str__(self):
         return self.to_json()

From a6351d044f4d851220ff98a1dbaba7a01f1767ba Mon Sep 17 00:00:00 2001
From: gongweibao
Date: Tue, 13 Oct 2020 03:38:59 +0000
Subject: [PATCH 9/9] add test=develop

---
 python/edl/collective/distribute_reader.py    | 15 +--
 python/edl/collective/state.py                |  2 +-
 python/edl/tests/unittests/CMakeLists.txt     |  3 +-
 .../tests/unittests/etcd_trainer_test_base.py | 37 +++++++
 .../edl/tests/unittests/test_data_reader.py   | 83 +++++++++++++---
 python/edl/utils/batch_data_accesser.py       |  2 +-
 python/edl/utils/batch_data_generator.py      |  3 +-
 python/edl/utils/{env.py => job_env.py}       | 53 ----------
 python/edl/utils/trainer_env.py               | 99 +++++++++++++++++++
 9 files changed, 219 insertions(+), 78 deletions(-)
 create mode 100644 python/edl/tests/unittests/etcd_trainer_test_base.py
 rename python/edl/utils/{env.py => job_env.py} (80%)
 create mode 100644 python/edl/utils/trainer_env.py

diff --git a/python/edl/collective/distribute_reader.py b/python/edl/collective/distribute_reader.py
index 24e31028..e0f87995 100644
--- a/python/edl/collective/distribute_reader.py
+++ b/python/edl/collective/distribute_reader.py
@@ -14,7 +14,7 @@
 from __future__ import print_function
 import multiprocessing
 
-from edl.uitls import reader as edl_reader
+from edl.utils import reader as edl_reader
 from edl.utils import batch_data_generator
 from edl.utils import batch_data_accesser
 from edl.utils import exceptions
@@ -68,7 +68,7 @@ def stop(self):
     def __exit__(self):
         self.stop()
 
-    def _check(self, proc, error_queue):
+    def _check_proc(self, proc, error_queue):
         if proc.is_alive():
             return True
 
@@ -79,8 +79,9 @@ def _check_proc(self, proc, error_queue):
 
         if len(error_queue) > 0:
             raise exceptions.EdlDataProcessError(error_queue[0])
-        else:
-            raise exceptions.EdlDataProcessError("process exit:{}".format(exitcode))
+
+        raise exceptions.EdlDataProcessError("process exit:{}".format(exitcode))
 
     def _start_generator(self):
         args = batch_data_generator.Args()
@@ -93,6 +94,7 @@ def _start_generator(self):
         args.out_queue = self._generater_out_queue
         args.error_queue = self._generater_error_queue
         args.log_file_name = "{}_generator_{}.log".format(self._name, self._logger_no)
+        logger.debug("start generator args {}".format(args))
 
         self._generator = multiprocessing.Process(
             target=batch_data_generator.generate, args=(args,)
@@ -108,6 +110,7 @@ def _start_accesser(self):
         args.out_queue = self._accesser_out_queue
         args.queue_size = self._cache_capcity
         args.log_file_name = "{}_accesser_{}.log".format(self._name, self._logger_no)
+        logger.debug("start accesser args {}".format(args))
 
         self._accesser = multiprocessing.Process(
             target=batch_data_accesser.generate, args=(args,)
@@ -120,10 +123,10 @@ def __iter__(self):
         self._logger_no += 1
 
         while True:
-            if not self._check(self._accesser, self._accesser_error_queue):
+            if not self._check_proc(self._accesser, self._accesser_error_queue):
                 break
 
-            if not self._check(self._generator, self._generater_error_queue):
+            if not self._check_proc(self._generator, self._generater_error_queue):
                 break
 
             try:
diff --git a/python/edl/collective/state.py b/python/edl/collective/state.py
index 3c9d645e..f04606f4 100644
--- a/python/edl/collective/state.py
+++ b/python/edl/collective/state.py
@@ -21,7 +21,7 @@
 from edl.utils import train_status as edl_train_status
 from edl.utils import unique_name
 from edl.utils import env as edl_env
-from edl.utils import etcd_client
+from edl.discovery import etcd_client
 
 
 class DataCheckpoint(json_serializable.Serializable):
diff --git a/python/edl/tests/unittests/CMakeLists.txt b/python/edl/tests/unittests/CMakeLists.txt
index 577c07f0..aa07deab 100644
--- a/python/edl/tests/unittests/CMakeLists.txt
+++ b/python/edl/tests/unittests/CMakeLists.txt
@@ -74,9 +74,8 @@ endfunction()
 file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
 #FIXME
 string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
-LIST(REMOVE_ITEM TEST_OPS test_data_reader)
-LIST(REMOVE_ITEM TEST_OPS test_train)
 LIST(REMOVE_ITEM TEST_OPS test_launch)
+LIST(REMOVE_ITEM TEST_OPS test_train)
 foreach(TEST_OP ${TEST_OPS})
     bash_test_modules(${TEST_OP} START_BASH etcd_test.sh ENVS "PYTHON_EXECUTABLE=${PYTHON_EXECUTABLE}")
 endforeach(TEST_OP)
diff --git a/python/edl/tests/unittests/etcd_trainer_test_base.py b/python/edl/tests/unittests/etcd_trainer_test_base.py
new file mode 100644
index 00000000..e10efc49
--- /dev/null
+++ b/python/edl/tests/unittests/etcd_trainer_test_base.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import edl.utils.constants as constants
+import edl.utils.log_utils as log_utils
+import os
+import unittest
+from edl.discovery.etcd_client import EtcdClient
+
+g_etcd_endpoints = "127.0.0.1:2379"
+
+
+class EtcdTestBase(unittest.TestCase):
+    def setUp(self, job_id):
+        log_utils.get_logger(log_level=10)
+        self._etcd = EtcdClient([g_etcd_endpoints], root=job_id)
+        self._etcd.init()
+
+        self._old_environ = copy.copy(dict(os.environ))
+        constants.clean_etcd(self._etcd)
+
+    def tearDown(self):
+        os.environ.clear()
+        os.environ.update(self._old_environ)
+        constants.clean_etcd(self._etcd)
diff --git a/python/edl/tests/unittests/test_data_reader.py b/python/edl/tests/unittests/test_data_reader.py
index 62e712cc..4a31ecec 100644
--- a/python/edl/tests/unittests/test_data_reader.py
+++ b/python/edl/tests/unittests/test_data_reader.py
@@ -13,34 +13,91 @@
 # limitations under the License.
 
 import unittest
-from edl.collective.data_reader import DistributedDataReader
-from edl.collective.dataset import TxtFileSplitter
+import os
+from edl.collective import distribute_reader
+from edl.collective import dataset
+from edl.tests.unittests import etcd_trainer_test_base
+from edl.collective import state as edl_state
+
+
+class Args(object):
+    def __init__(self):
+        self.job_id = None
+        self.pod_id = None
+        self.global_rank = None
+        self.rank_in_pod = None
+        self.trainer_endpoints = None
+        self.pod_ids = None
+        self.gpu_id = "0"
+
+
+class TestDataReader(etcd_trainer_test_base.EtcdTestBase):
+    def _init_args(self, pod_id, global_rank, rank_in_pod):
+        args = Args()
+        args.job_id = self._job_id
+        args.pod_id = str(pod_id)
+        args.global_rank = str(global_rank)
+        args.rank_in_pod = str(rank_in_pod)
+        args.trainer_endpoints = ""
+        args.pod_ids = "0,1"
+        args.gpu_id = "0"
+
+        return args
+
+    def _update_env(self, pod_id, global_rank, rank_in_pod):
+        args = self._init_args(pod_id, global_rank, rank_in_pod)
+        proc_env = {
+            "PADDLE_JOB_ID": args.job_id,
+            "PADDLE_POD_ID": args.pod_id,
+            "EDL_POD_LEADER_ID": "0",
+            "PADDLE_ETCD_ENDPOINTS": "127.0.0.1:2379",
+            "PADDLE_TRAINER_ID": args.global_rank,
+            "PADDLE_TRAINER_RANK_IN_POD": args.rank_in_pod,
+            "EDL_POD_IDS": args.pod_ids,
+            "PADDLE_TRAINER_ENDPOINTS": args.trainer_endpoints,
+            "PADDLE_EDL_HDFS_HOME": "/usr/local/hadoop-2.7.7",
+            "PADDLE_EDL_HDFS_NAME": "",
+            "PADDLE_EDL_HDFS_UGI": "",
+            "PADDLE_EDL_HDFS_PATH": "hdfs://{}".format(args.job_id),
+            "PADDLE_EDL_ONLY_FOR_CE_TEST": "1",
+            "PADDLE_EDL_FS_CACHE": ".{}".format(args.job_id),
+            "PADDLE_EDL_SAVE_CHECKPOINT_INTER": "0",
+            "CUDA_VISIBLE_DEVICES": args.gpu_id,
+        }
+
+        os.environ.pop("https_proxy", None)
+        os.environ.pop("http_proxy", None)
+        os.environ.update(proc_env)
 
-class TestDataReader(unittest.TestCase):
     def setUp(self):
+        self._job_id = "test_data_reader"
+        super(TestDataReader, self).setUp(self._job_id)
+
         self._file_list = ["./data_server/a.txt", "./data_server/b.txt"]
 
         self._data = {}
         for idx, p in enumerate(self._file_list):
-            s = TxtFileSplitter(p)
-            for r in s:
+            reader = dataset.TxtFileSplitter(p)
+            for rec in reader:
                 if idx not in self._data:
                     self._data[idx] = []
-                d = ((p), (r[0], r[1:]))
-                self._data[idx].append(d)  # [(path), (rec_no, splitted_fields)]...
+            self._data[idx].append(rec)
 
     def test_data_reader(self):
-        reader1 = DistributedDataReader(
+        self._update_env(pod_id="0", global_rank=0, rank_in_pod=0)
+        state = edl_state.PaddleState(total_batch_size=1)
+        reader1 = distribute_reader.Reader(
+            state=state,
             file_list=self._file_list,
-            file_splitter_cls=TxtFileSplitter,
-            splitted_data_field=["line"],
+            file_splitter_cls=dataset.TxtFileSplitter,
             batch_size=1,
         )
 
-        reader2 = DistributedDataReader(
+        self._update_env(pod_id="1", global_rank=1, rank_in_pod=0)
+        state = edl_state.PaddleState(total_batch_size=1)
+        reader2 = distribute_reader.Reader(
+            state=state,
             file_list=self._file_list,
-            file_splitter_cls=TxtFileSplitter,
-            splitted_data_field=["line"],
+            file_splitter_cls=dataset.TxtFileSplitter,
             batch_size=1,
         )
diff --git a/python/edl/utils/batch_data_accesser.py b/python/edl/utils/batch_data_accesser.py
index ed307ad3..e794ca55 100644
--- a/python/edl/utils/batch_data_accesser.py
+++ b/python/edl/utils/batch_data_accesser.py
@@ -15,7 +15,7 @@
 
 import threading
 
-from edl.uitls import reader as edl_reader
+from edl.utils import reader as edl_reader
 from edl.utils import data_server
 from edl.utils import data_server_client
 from edl.utils import etcd_db
diff --git a/python/edl/utils/batch_data_generator.py b/python/edl/utils/batch_data_generator.py
index 75427ae9..475979a2 100644
--- a/python/edl/utils/batch_data_generator.py
+++ b/python/edl/utils/batch_data_generator.py
@@ -14,7 +14,6 @@
 from __future__ import print_function
 from edl.utils import data_server_client
 from edl.utils import data_server_pb2
-from edl.utils import edl_process
 
 logger = None
@@ -32,7 +31,7 @@ def __init__(self):
         self.error_queue = None
         self.log_file_name = None
 
 
-class Generator(edl_process.ProcessWrapper):
+class Generator(object):
     """
     1. get file_list from data_server_leader
     2. parse files of file_list and put BatchData to out_queue
diff --git a/python/edl/utils/env.py b/python/edl/utils/job_env.py
similarity index 80%
rename from python/edl/utils/env.py
rename to python/edl/utils/job_env.py
index 0960affc..642e9a9e 100644
--- a/python/edl/utils/env.py
+++ b/python/edl/utils/job_env.py
@@ -174,56 +174,3 @@ def __str__(self):
         for k, v in six.iteritems(vars(self)):
             s += "{}:{} ".format(k, v)
         return s
-
-
-class TrainerEnv(object):
-    """
-    Parse all envs when edl_launch starts a trainer.
- """ - - def __init__(self, args=None): - self._job_id = os.environ["PADDLE_JOB_ID"] - self._pod_id = os.environ["PADDLE_POD_ID"] - self._pod_leader_id = os.environ["EDL_POD_LEADER_ID"] - self._etcd_endpoints = os.environ["PADDLE_ETCD_ENDPOINTS"] - - self._global_rank = int(os.environ["PADDLE_TRAINER_ID"]) - self._rank_in_pod = int(os.environ["PADDLE_TRAINER_RANK_IN_POD"]) - self._trainer_endpoints = os.environ["PADDLE_TRAINER_ENDPOINTS"] - self._pod_ids = os.environ["EDL_POD_IDS"].split(",") - - @property - def pod_leader_id(self): - return self._pod_leader_id - - @property - def pod_ids(self): - return self._pod_ids - - @property - def pod_id(self): - return self._pod_id - - @property - def global_rank(self): - return self._global_rank - - @property - def rank_in_pod(self): - return self._rank_in_pod - - @property - def trainer_endpoints(self): - return self._trainer_endpoints - - @property - def size(self): - return len(self._trainer_endpoints) - - @property - def job_id(self): - return self._job_id - - @property - def etcd_endpoints(self): - return self._etcd_endpoints diff --git a/python/edl/utils/trainer_env.py b/python/edl/utils/trainer_env.py new file mode 100644 index 00000000..80baa3b1 --- /dev/null +++ b/python/edl/utils/trainer_env.py @@ -0,0 +1,99 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from edl.utils import job_env + + +class TrainerEnv(object): + """ + Parse all envs when edl_launch starts a trainer. 
+ """ + + def __init__(self, args=None): + self._job_id = os.environ["PADDLE_JOB_ID"] + self._pod_id = os.environ["PADDLE_POD_ID"] + self._pod_leader_id = os.environ["EDL_POD_LEADER_ID"] + self._etcd_endpoints = os.environ["PADDLE_ETCD_ENDPOINTS"] + + self._global_rank = int(os.environ["PADDLE_TRAINER_ID"]) + self._rank_in_pod = int(os.environ["PADDLE_TRAINER_RANK_IN_POD"]) + self._trainer_endpoints = os.environ["PADDLE_TRAINER_ENDPOINTS"] + self._pod_ids = os.environ["EDL_POD_IDS"].split(",") + self._ce_test = int(os.getenv("PADDLE_EDL_ONLY_FOR_CE_TEST", "0")) + self._get_hdfs(args) + + def _get_hdfs(self, args): + # hdfs + self._hdfs_home = job_env.get_from_dict_or_env( + args, "hdfs_home", "PADDLE_EDL_HDFS_HOME" + ) + self._hdfs_name = job_env.get_from_dict_or_env( + args, "hdfs_name", "PADDLE_EDL_HDFS_NAME" + ) + self._hdfs_path = job_env.get_from_dict_or_env( + args, "hdfs_path", "PADDLE_EDL_HDFS_PATH" + ) + self._hdfs_ugi = job_env.get_from_dict_or_env( + args, "hdfs_ugi", "PADDLE_EDL_HDFS_UGI" + ) + + # assert hdfs value + if not self._ce_test: + assert ( + len(self._hdfs_home) > 3 + and len(self._hdfs_name) > 6 + and len(self._hdfs_ugi) > 3 + and len(self._hdfs_path) > 0 + ), "hdfs environ must set" + else: + assert ( + len(self._hdfs_home) > 3 and len(self._hdfs_path) > 0 + ), "hdfs environ must set" + + @property + def pod_leader_id(self): + return self._pod_leader_id + + @property + def pod_ids(self): + return self._pod_ids + + @property + def pod_id(self): + return self._pod_id + + @property + def global_rank(self): + return self._global_rank + + @property + def rank_in_pod(self): + return self._rank_in_pod + + @property + def trainer_endpoints(self): + return self._trainer_endpoints + + @property + def size(self): + return len(self._trainer_endpoints) + + @property + def job_id(self): + return self._job_id + + @property + def etcd_endpoints(self): + return self._etcd_endpoints