From c2ec0ddaa9fd329d47e958bccd2696141e1fd6ba Mon Sep 17 00:00:00 2001 From: qwz Date: Fri, 28 Jul 2023 10:07:32 +0800 Subject: [PATCH 1/8] fix wrong dataset file path --- RepoCoder/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/RepoCoder/utils.py b/RepoCoder/utils.py index af2c24a..6233a3d 100644 --- a/RepoCoder/utils.py +++ b/RepoCoder/utils.py @@ -20,11 +20,11 @@ class CONSTANTS: rgrg = 'r-g-r-g' # RepoCoder, two-stage retrieval and generation class FilePathBuilder: - api_completion_benchmark = 'datasets/random-api-completion.test.jsonl' - random_line_completion_benchmark = 'datasets/random-line-completion.test.jsonl' + api_completion_benchmark = 'datasets/api_level_completion_2k_context_codex.test.jsonl' + random_line_completion_benchmark = 'datasets/line_level_completion_2k_context_codex.test.jsonl' # short version for codegen - short_api_completion_benchmark = 'datasets/random-api-completion-short-version.test.jsonl' - short_random_line_completion_benchmark = 'datasets/random-line-completion-short-version.test.jsonl' + short_api_completion_benchmark = 'datasets/api_level_completion_1k_context_codegen.test.jsonl' + short_random_line_completion_benchmark = 'datasets/line_level_completion_1k_context_codegen.test.jsonl' repo_base_dir = 'repositories/line_and_api_level' @staticmethod From cecce36b082e0e78c0c4097a7cb4777e5e7d5f28 Mon Sep 17 00:00:00 2001 From: Fengji Zhang Date: Fri, 3 May 2024 11:27:23 +0800 Subject: [PATCH 2/8] update codegen inference script --- RepoCoder/README.md | 1 + RepoCoder/codegen_inference.py | 77 ++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+) create mode 100644 RepoCoder/codegen_inference.py diff --git a/RepoCoder/README.md b/RepoCoder/README.md index 84c05b6..079b893 100644 --- a/RepoCoder/README.md +++ b/RepoCoder/README.md @@ -22,6 +22,7 @@ This project contains the basic components of RepoCoder. 
Here is an overview: |-- build_prompt.py # build the prompt with the unfinished code and the retrieved code snippets |-- run_pipeline.py # run the code completion pipeline |-- compute_score.py # evaluate the performance of the code completion +|-- codegen_inference.py # an example script for using CodeGen to generate code completions |-- utils.py # utility functions |-- datasets/datasets.zip # the input data for the code completion task |-- function_level_completion_4k_context_codex.test.jsonl diff --git a/RepoCoder/codegen_inference.py b/RepoCoder/codegen_inference.py new file mode 100644 index 0000000..5ff3519 --- /dev/null +++ b/RepoCoder/codegen_inference.py @@ -0,0 +1,77 @@ +import torch +import tqdm +import json +from transformers import AutoModelForCausalLM, AutoTokenizer + + +class Tools: + @staticmethod + def load_jsonl(path): + with open(path, 'r') as f: + return [json.loads(line) for line in f.readlines()] + + @staticmethod + def dump_jsonl(obj, path): + with open(path, 'w') as f: + for line in obj: + f.write(json.dumps(line) + '\n') + + +class CodeGen: + def __init__(self, model_name, batch_size): + self.model_name = model_name + self.model = AutoModelForCausalLM.from_pretrained(model_name) + self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") + self.tokenizer.add_special_tokens({'pad_token': self.tokenizer.eos_token}) + self.model.cuda() + self.batch_size = batch_size + print('done loading model') + + def _get_batchs(self, prompts, batch_size): + batches = [] + for i in range(0, len(prompts), batch_size): + batches.append(prompts[i:i+batch_size]) + return batches + + def _generate_batch(self, prompt_batch, max_new_tokens=100): + prompts = self.tokenizer(prompt_batch, return_tensors='pt', padding=True, truncation=True) + + with torch.no_grad(): + gen_tokens = self.model.generate( + input_ids = prompts['input_ids'].cuda(), + attention_mask = prompts['attention_mask'].cuda(), + do_sample=False, + max_new_tokens=max_new_tokens, + ) + gen_text = self.tokenizer.batch_decode(gen_tokens, skip_special_tokens=True) + for i in range(len(gen_text)): # remove the prompt + gen_text[i] = gen_text[i][len(prompt_batch[i]):] + return gen_text + + def batch_generate(self, file): + print(f'generating from {file}') + lines = Tools.load_jsonl(file) + # have a new line at the end + prompts = [f"{line['prompt']}\n" for line in lines] + batches = self._get_batchs(prompts, self.batch_size) + gen_text = [] + for batch in tqdm.tqdm(batches): + gen_text.extend(self._generate_batch(batch)) + print(f'generated {len(gen_text)} samples') + assert len(gen_text) == len(prompts) + new_lines = [] + for line, gen in zip(lines, gen_text): + new_lines.append({ + 'prompt': line['prompt'], + 'metadata': line['metadata'], + 'choices': [{'text': gen}] + }) + Tools.dump_jsonl(new_lines, file.replace('.jsonl', f'_{self.model_name.split("/")[-1]}.jsonl')) + + +if __name__ == '__main__': + file_path = 'datasets/line_level_completion_1k_context_codegen.test.jsonl' + tiny_codegen = 'Salesforce/codegen-350M-mono' + + cg = CodeGen(tiny_codegen, batch_size=8) + cg.batch_generate(file_path) From 6a6ef6359a3587e134d8350c3eebb0e639e7789a Mon Sep 17 00:00:00 2001 From: Fengji Zhang Date: Fri, 3 May 2024 11:37:01 +0800 Subject: [PATCH 3/8] fix missing repo vectors --- RepoCoder/run_pipeline.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/RepoCoder/run_pipeline.py b/RepoCoder/run_pipeline.py index 6592ce1..e44e29e 100644 --- a/RepoCoder/run_pipeline.py +++ 
b/RepoCoder/run_pipeline.py @@ -12,13 +12,12 @@ from utils import CONSTANTS, CodexTokenizer def make_repo_window(repos, window_sizes, slice_sizes): - worker = MakeWindowWrapper(None, repos, window_sizes, slice_sizes) - worker.window_for_repo_files() + MakeWindowWrapper(None, repos, window_sizes, slice_sizes).window_for_repo_files() + vectorizer = BagOfWords + BuildVectorWrapper(None, vectorizer, repos, window_sizes, slice_sizes).vectorize_repo_windows() def run_RG1_and_oracle_method(benchmark, repos, window_sizes, slice_sizes): - # build code snippets for all the repositories - make_repo_window(repos, window_sizes, slice_sizes) # build code snippets for vanilla retrieval-augmented approach and ground truth MakeWindowWrapper(benchmark, repos, window_sizes, slice_sizes).window_for_baseline_and_ground() # build vector for vanilla retrieval-augmented approach and ground truth @@ -62,6 +61,9 @@ def run_RepoCoder_method(benchmark, repos, window_sizes, slice_sizes, prediction window_sizes = [20] slice_sizes = [2] # 20 / 2 = 10 + # build window for the repos + make_repo_window(repos, window_sizes, slice_sizes) + # build prompt for the RG1 and oracle methods run_RG1_and_oracle_method(CONSTANTS.api_benchmark, repos, window_sizes, slice_sizes) From 227e96e93e1538aea3cb408667f7d10c111d0811 Mon Sep 17 00:00:00 2001 From: Fengji Zhang Date: Mon, 20 May 2024 10:34:30 +0800 Subject: [PATCH 4/8] fix wrong constant name --- RepoCoder/make_window.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/RepoCoder/make_window.py b/RepoCoder/make_window.py index 0da90b6..e991ad5 100644 --- a/RepoCoder/make_window.py +++ b/RepoCoder/make_window.py @@ -138,7 +138,7 @@ def build_window(self): } }) print(f'build {len(code_windows)} ground truth windows for {self.repo} with window size {self.window_size}') - output_path = FilePathBuilder.search_first_window_path(self.benchmark, CONSTANTS.rg, self.repo, self.window_size) + output_path = FilePathBuilder.search_first_window_path(self.benchmark, CONSTANTS.gt, self.repo, self.window_size) Tools.dump_pickle(code_windows, output_path) class PredictionWindowMaker: From f7e8e6d729e358e4e3031521d7d96ea9bfecd524 Mon Sep 17 00:00:00 2001 From: "fengji.zhang" Date: Sun, 11 Aug 2024 18:16:31 +0800 Subject: [PATCH 5/8] scripts for constructing API and line completion datasets --- RepoCoder/make_dataset/api_benchmark.py | 246 +++++++++++++++++ RepoCoder/make_dataset/ast_visitors.py | 268 +++++++++++++++++++ RepoCoder/make_dataset/config.py | 15 ++ RepoCoder/make_dataset/file_visitors.py | 224 ++++++++++++++++ RepoCoder/make_dataset/make_dataset_utils.py | 103 +++++++ RepoCoder/make_dataset/random_benchmark.py | 122 +++++++++ 6 files changed, 978 insertions(+) create mode 100644 RepoCoder/make_dataset/api_benchmark.py create mode 100644 RepoCoder/make_dataset/ast_visitors.py create mode 100644 RepoCoder/make_dataset/config.py create mode 100644 RepoCoder/make_dataset/file_visitors.py create mode 100644 RepoCoder/make_dataset/make_dataset_utils.py create mode 100644 RepoCoder/make_dataset/random_benchmark.py diff --git a/RepoCoder/make_dataset/api_benchmark.py b/RepoCoder/make_dataset/api_benchmark.py new file mode 100644 index 0000000..21bd0ad --- /dev/null +++ b/RepoCoder/make_dataset/api_benchmark.py @@ -0,0 +1,246 @@ +import os +import ipdb +import random +from tqdm import tqdm +from collections import defaultdict +from concurrent.futures import as_completed, ProcessPoolExecutor + +from file_visitors import FileDefinedAPI, FileImportedAPI, FileCallAPI +from 
make_dataset_utils import Tools, CodexTokenizer + + +class APICallLocator: + def __init__(self, base_dir, repo): + self.base_dir = base_dir + self.repo = repo + self.source_code_files = Tools.iterate_repository(base_dir, repo) + + def collect_defined_apis_for_each_file(self): + file_define_api = FileDefinedAPI(self.repo, self.source_code_files) + defined_apis_by_file = file_define_api.get_defined_apis_by_file() + + init_files = dict() + for fpath_tuple in self.source_code_files.keys(): + if fpath_tuple[-1] == '__init__.py': + init_files[fpath_tuple] = self.source_code_files[fpath_tuple] + file_import_api = FileImportedAPI(self.repo, init_files, defined_apis_by_file) + imported_apis_of_init_files = file_import_api.get_imported_apis_by_file() + for module_path_tuple, imported_apis_info in imported_apis_of_init_files.items(): + defined_apis_info = defined_apis_by_file[module_path_tuple] + defined_apis_by_file[module_path_tuple] = {**defined_apis_info, **imported_apis_info} + return defined_apis_by_file + + def collect_available_apis_for_each_file(self): + available_apis_by_file = self.collect_defined_apis_for_each_file() + non_init_files = dict() + for fpath_tuple in self.source_code_files.keys(): + if fpath_tuple[-1] != '__init__.py': + non_init_files[fpath_tuple] = self.source_code_files[fpath_tuple] + file_import_api = FileImportedAPI(self.repo, non_init_files, available_apis_by_file) + imported_apis_of_non_init_files = file_import_api.get_imported_apis_by_file() + for module_path_tuple, imported_apis_info in imported_apis_of_non_init_files.items(): + defined_apis_info = available_apis_by_file[module_path_tuple] + available_apis_by_file[module_path_tuple] = {**defined_apis_info, **imported_apis_info} + return available_apis_by_file + + def _build_func_signature_context_with_positions(self, base_dir, fpath_tuple, func_header_start_line_no, func_body_start_line_no, class_name): + file_path = os.path.join(base_dir, *fpath_tuple) + code = Tools.read_code(file_path) + func_signature_and_doc = code.splitlines()[func_header_start_line_no-1:func_body_start_line_no-1] # lineno is 1-indexed + intent = 0 + if not func_signature_and_doc: + ipdb.set_trace() + for i in func_signature_and_doc[0]: + if i == ' ': intent += 1 + else: break + func_signature_and_doc = [i[intent:] for i in func_signature_and_doc] + if class_name: + func_signature_and_doc = [f'class {class_name}:'] + func_signature_and_doc + return '\n'.join(func_signature_and_doc) + + def _build_func_body_context_with_positions(self, base_dir, fpath_tuple, func_start_line_no, func_end_line_no, class_name): + file_path = os.path.join(base_dir, *fpath_tuple) + code = Tools.read_code(file_path) + func_body = code.splitlines()[func_start_line_no-1:func_end_line_no] # lineno is 1-indexed + intent = 0 + if not func_body: + ipdb.set_trace() + for i in func_body[0]: + if i == ' ': intent += 1 + else: break + func_body = [i[intent:] for i in func_body] + if class_name: + func_body = [f'class {class_name}:'] + func_body + return '\n'.join(func_body) + + def _build_api_set_for_available_api_dicts(self, available_apis_by_file): + def __buil_context_for_available_api(available_api): + try: + func_header_start_line_no = available_api['func_node_start_end_positions']['start_lineno'] + func_end_line_no = available_api['func_node_start_end_positions']['end_lineno'] + func_body_start_line_no = available_api['func_body_start_end_positions']['start_lineno'] if available_api['func_body_start_end_positions'] else func_end_line_no + fpath_tuple = 
available_api['current_fpath_tuple'] + class_name = available_api['class_name'] if 'class_name' in available_api else None + func_signature_context = self._build_func_signature_context_with_positions(self.base_dir, fpath_tuple, func_header_start_line_no, func_body_start_line_no, class_name) + func_body_context = self._build_func_body_context_with_positions(self.base_dir, fpath_tuple, func_header_start_line_no, func_end_line_no, class_name) + except Exception as e: + print(e) + ipdb.set_trace() + + return (available_api['api_name'], func_signature_context, func_body_context) + + available_api_set_by_file = defaultdict(set) + for fpath_tuple in available_apis_by_file.keys(): + # imported apis, imported classes, imported modules, imported members + outer_apis = set() + outer_apis |= set([__buil_context_for_available_api(i) for i in available_apis_by_file[fpath_tuple]['imported_outer_apis']]) + + class_apis = set() + for class_info in available_apis_by_file[fpath_tuple]['imported_classes']: + class_name = class_info['class_name'] + located_module_path_tuple = class_info['located_module_path_tuple'] + class_apis |= set([ + __buil_context_for_available_api(i) for i in + available_apis_by_file[located_module_path_tuple]['defined_classes'][class_name] + ]) + + # TODO: cannot find the original position of the imported members from __init__ + # members = set([i['member_name'] for i in available_apis_by_file[fpath_tuple]['imported_members']]) + available_api_set_by_file[fpath_tuple] = outer_apis | class_apis + + for fpath_tuple in available_apis_by_file.keys(): + module_apis = set() + imported_modules = [i['located_module_path_tuple'] for i in available_apis_by_file[fpath_tuple]['imported_modules']] + for imported_module_path_tuple in imported_modules: + module_apis |= available_api_set_by_file[imported_module_path_tuple] + available_api_set_by_file[fpath_tuple] |= module_apis + + return available_api_set_by_file + + def find_intra_api_calls_for_each_file(self): + available_apis_by_file = self.collect_available_apis_for_each_file() + available_api_set_by_file = self._build_api_set_for_available_api_dicts(available_apis_by_file) + file_call_api = FileCallAPI(self.repo, self.source_code_files) + called_apis_by_file = file_call_api.get_called_apis_by_file() + for fpath_tuple, called_apis_info in called_apis_by_file.items(): + available_api_set = available_api_set_by_file[fpath_tuple] + called_intra_apis = [] + for called_api in called_apis_info: + for available_api in available_api_set: + if called_api['api_name'] == available_api[0]: + called_api['signature_context'] = available_api[1] + called_api['body_context'] = available_api[2] + called_intra_apis.append(called_api) + break + called_apis_by_file[fpath_tuple] = called_intra_apis + return called_apis_by_file + + +class APIHoleDigger: + def __init__(self, repo_base_dir, cache_base_dir, repo, context_max_tokens=2000): + self.repo_base_dir = repo_base_dir + self.repo = repo + self.chosen_apis_cache_path = os.path.join(cache_base_dir, f'{self.repo}-random-api-200.pkl') + self.context_max_tokens = context_max_tokens + self.tokenizer = CodexTokenizer() + + def _make_context_prompt_by_prepending(self, base_dir, fpath_tuple, called_line_no, additional_context, context_max_tokens): + # line_no is 0-indexed + code = Tools.read_code(os.path.join(base_dir, *fpath_tuple)) + previous_code_lines = code.splitlines()[:called_line_no] + if not previous_code_lines: + ipdb.set_trace() + additional_lines = [] + if additional_context: + additional_lines = ["'''Relevant 
Helpful functions:"] + additional_context.splitlines() + ["'''"] + trimed_context, context_start_lineno = Tools.trim_context(self.tokenizer, previous_code_lines, context_max_tokens) + context_lines = additional_lines + trimed_context + return '\n'.join(context_lines), context_start_lineno + + def _get_api_call_ground_truth(self, base_dir, fpath_tuple, start_line_no, end_line_no): + code = Tools.read_code(os.path.join(base_dir, *fpath_tuple)) + code_lines = code.splitlines() + return '\n'.join(code_lines[start_line_no:end_line_no+1]) + + def _dig_hole(self, called_api, context_type): + fpath_tuple = called_api['current_fpath_tuple'] + called_line_no = called_api['api_call_node_start_end_positions']['start_lineno'] - 1 + if context_type == 'none': + additional_context = '' + elif context_type == 'signature': + additional_context = called_api['signature_context'] + elif context_type == 'body': + additional_context = called_api['body_context'] + context_prompt, context_start_lineno = self._make_context_prompt_by_prepending(self.repo_base_dir, fpath_tuple, called_line_no, additional_context, self.context_max_tokens) + called_end_line_no = called_api['api_call_node_start_end_positions']['end_lineno'] - 1 + ground_truth = self._get_api_call_ground_truth(self.repo_base_dir, fpath_tuple, called_line_no, called_end_line_no) + return context_prompt, context_start_lineno, ground_truth, fpath_tuple, called_line_no + + def dig_holes(self, context_type): + chosen_apis = Tools.load_pickle(self.chosen_apis_cache_path) + prompts = [] + print(f'digging holes for {self.repo}...') + with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor: + future_to_prompt = {executor.submit(self._dig_hole, api, context_type): index for index, api in enumerate(chosen_apis)} + for future in tqdm(as_completed(future_to_prompt), total=len(future_to_prompt)): + index = future_to_prompt[future] + prompt = future.result() + prompts.append((prompt, index)) + prompts = sorted(prompts, key=lambda x: x[1]) + return [i[0] for i in prompts] + + def random_chosen(self, api_call_locator, num=200): + if os.path.exists(self.chosen_apis_cache_path): + return + called_apis_by_file = api_call_locator.find_intra_api_calls_for_each_file() + all_called_apis = [i for apis in list(called_apis_by_file.values()) for i in apis] + random.shuffle(all_called_apis) + Tools.dump_pickle(all_called_apis[:num], self.chosen_apis_cache_path) + + +def build_random_API_benchmark(): + REPO_BASE_DIR = 'downloaded_repos' + OUT_BASE_DIR = 'output' + CACHE_BASE_DIR = 'cache' + repos = [ + 'huggingface_diffusers', + 'nerfstudio-project_nerfstudio', + 'awslabs_fortuna', + 'huggingface_evaluate', + 'google_vizier', + 'alibaba_FederatedScope', + 'pytorch_rl', + 'opendilab_ACE' + ] + holedigger_by_repo = dict() + for repo in repos: + locator = APICallLocator(REPO_BASE_DIR, repo) + holedigger = APIHoleDigger(REPO_BASE_DIR, CACHE_BASE_DIR, repo, context_max_tokens=2000) + holedigger.random_chosen(locator) + holedigger_by_repo[repo] = holedigger + + for context_type in ['none']: + prompts_by_repo = dict() + for repo in repos: + holedigger = holedigger_by_repo[repo] + prompts = holedigger.dig_holes(context_type) + prompts_by_repo[repo] = prompts + json_lines = [] + for repo, prompts in prompts_by_repo.items(): + json_lines.extend([ + { + 'prompt': prompt, + 'metadata': { + 'task_id': f'{repo}/{index}', + 'ground_truth': ground_truth, + 'fpath_tuple': fpath_tuple, + 'context_start_lineno': context_start_lineno, + 'line_no': end_line_no + } + } + for index, (prompt, 
context_start_lineno, ground_truth, fpath_tuple, end_line_no) in enumerate(prompts) + ]) + Tools.dump_jsonl(json_lines, os.path.join(OUT_BASE_DIR, f'random-api-completion.jsonl')) + +if __name__ == '__main__': + build_random_API_benchmark() \ No newline at end of file diff --git a/RepoCoder/make_dataset/ast_visitors.py b/RepoCoder/make_dataset/ast_visitors.py new file mode 100644 index 0000000..033b923 --- /dev/null +++ b/RepoCoder/make_dataset/ast_visitors.py @@ -0,0 +1,268 @@ +""" +parse Python files to find: +1. the start and end locations of called apis +2. the definations of all apis and classes that can be called +""" +import ast +import astunparse +from collections import defaultdict +import ipdb + +class APICallVisitor(ast.NodeVisitor): + def __init__(self, fpath_tuple): + super().__init__() + # stores all the called apis, including intra-project and public ones + self.called_apis = list() + self.fpath_tuple = fpath_tuple + + def visit_Call(self, node: ast.AST): + # start/ending positions for the api call(including api_prefix, api_name, and arguments) + start_lineno = node.lineno + start_col = node.col_offset + end_lineno = node.end_lineno + end_col = node.end_col_offset + # lineno starts from 1, col_offset starts from 0, and end_offset contains the last symbol + # to extract the segment, use line[start_col:end_col] + start_end_positions = { + 'start_lineno': start_lineno, + 'start_col': start_col, + 'end_lineno': end_lineno, + 'end_col': end_col + } + self.generic_visit(node) + + func = node.func + try: + if self._is_getattr_call(node): + # for example "getattr(mylib.a.b, 'const')(x, y)" + module = node.func.args[0] # mylib.a.b + module = astunparse.unparse(module).strip() + func = node.func.args[1] # 'const' + if isinstance(func, ast.Constant): + api_name = func.value + else: # for example "getattr(mylib.a.b, identifier)(x, y)" + return + elif isinstance(func, ast.Attribute) or isinstance(func, ast.Name): # "x, y = W.func(a, b, c)" + module, api_name = self._get_func_name_and_module_from_attribute_or_name(func) + elif isinstance(func, ast.Subscript): # for example "x = W.m[0]()" + # module, api_name = self._get_func_name_and_module_from_attribute_or_name(func.value) + # api_name += astunparse.unparse(func).strip()[len(astunparse.unparse(func.value))-1:] + return + elif isinstance(func, ast.Call): # for example "x = W.m()()" + # module, api_name = self._get_func_name_and_module_from_attribute_or_name(func.func) + # api_name += astunparse.unparse(func).strip()[len(astunparse.unparse(func.func))-1:] + return + elif isinstance(func, ast.IfExp): # for example "(x if None else y)()" + # module, api_name = '', astunparse.unparse(func).strip() + return + elif isinstance(func, ast.Lambda): # for example "lambda: x()" + return + elif isinstance(func, ast.BinOp): # for example "(ctypes.c_int64 * self._output_ndims[i])()" + return + elif isinstance(func, ast.BoolOp): # for example "(_load or (lambda v: v))(value_)" + return + elif func.id == 'getattr': # don't need to handle getattr() because it is handled in _is_getattr_call() + return + self.called_apis.append({ + 'api_name': api_name, + 'api_call_prefix': module, + 'api_call_node_start_end_positions': start_end_positions, + 'current_fpath_tuple': self.fpath_tuple + }) + except Exception as e: + print(e) + print(astunparse.unparse(node)) + ipdb.set_trace() + + + def _get_func_name_and_module_from_attribute_or_name(self, node: ast.AST): + if isinstance(node, ast.Attribute): + module = astunparse.unparse(node.value).strip() + api_name = 
node.attr + return module, api_name + elif isinstance(node, ast.Name): + return '', node.id + + + def _is_getattr_call(self, base_node: ast.AST) -> bool: + """ + finds the pattern getattr(mylib, 'const')() + """ + if not isinstance(base_node, ast.Call): + return False + node = base_node.func + if not isinstance(node, ast.Call): + return False + if not (isinstance(node.func, ast.Name) and node.func.id == "getattr"): + return False + return True + + +class APIImportVisitor(ast.NodeVisitor): + def __init__(self, file_module, fpath_tuple): + super().__init__() + self.file_module = file_module + self.fpath_tuple = fpath_tuple + self.renamed_api = dict() # alias of import and func + self.imported_apis = [] + + def visit_Import(self, node: ast.AST): + ''' + for example "import numpy.array as arr": + "numpy" is the module stored as value in the "api_path", + "array" is the name stored as value in the "remapped", + "arr" is the alias stored as key for the "remapped" and "api_path" + and put "arr" into the "imported_apis" + ''' + self.generic_visit(node) + for n in node.names: + api_name = n.name + api_as_name = '' + module = '' + if '.' in api_name: # for example "import numpy.array" + api_name = n.name.split('.')[-1] + module = '.'.join(n.name.split('.')[:-1]) + if n.asname: + api_as_name = n.asname + self.renamed_api[api_as_name] = api_name + self.imported_apis.append({ + 'api_name': api_name, + 'api_path': module, + 'api_as_name': api_as_name, + 'current_fpath_tuple': self.fpath_tuple # for calculating the similarity between two packages + }) + + def visit_ImportFrom(self, node: ast.AST): + ''' + for example "from numpy import array as arr": + "numpy" is the module stored as value in the "api_path", + "array" is the name stored as value in the "remapped", + "arr" is the alias stored as key for the "remapped" and "api_path" + and put "arr" into the "imported_apis" + ''' + self.generic_visit(node) + module = node.module if node.module else '' + api_as_name = '' + if node.level: # relative import, rebuild module + if not module and node.level == 1: # "from . 
import a" means import from __init__.py + return + file_module = self.file_module + if self.fpath_tuple[-1] == '__init__.py': + file_module += '.__init__' # fix the module level of __init__ when doing relative import + new_module_chain = file_module.split('.')[:-node.level] + [module] + module = '.'.join([i for i in new_module_chain if i]) # in case module is empty + + for n in node.names: + api_name = n.name + if n.asname: + api_as_name = n.asname + self.renamed_api[api_as_name] = api_name + self.imported_apis.append({ + 'api_name': api_name, + 'api_path': module, + 'api_as_name': api_as_name, + 'current_fpath_tuple': self.fpath_tuple + }) + + +class APIDefineVisitor(ast.NodeVisitor): + def __init__(self, fpath_tuple): + super().__init__() + self.defined_outer_apis = [] + self.defined_classes = defaultdict(list) + self.fpath_tuple = fpath_tuple + + def store_parent_node(self, root): + ''' + Remember to first run this function before calling visit + ''' + for node in ast.walk(root): # recursive visit + for child in ast.iter_child_nodes(node): + child.parent = node + + def _get_positon(self, node): + start_lineno = node.lineno + start_col = node.col_offset + end_lineno = node.end_lineno + end_col = node.end_col_offset + """ + lineno starts from 1, col_offset starts from 0, and end_offset contains the last symbol + to extract the segment, use line[start_col:end_col] + """ + start_end_positions = { + 'start_lineno': start_lineno, + 'start_col': start_col, + 'end_lineno': end_lineno, + 'end_col': end_col + } + return start_end_positions + + def _build_api_path(self, node): + api_path = [] + current_node = node + while hasattr(current_node, 'parent'): + current_node = current_node.parent + if isinstance(current_node, ast.ClassDef): + api_path.insert(0, ('class', current_node.name)) + elif isinstance(current_node, ast.Module): + break + return api_path + + def _get_func_type(self, node): + """ + tell whether the function is a class method or an outer method or a local method + if it is a class method, include the class name + if it is a local method, return None + """ + current_node = node + parent_nodes = [] + while hasattr(current_node, 'parent'): + current_node = current_node.parent + if isinstance(current_node, ast.FunctionDef): + parent_nodes.append(('func', current_node.name)) + elif isinstance(current_node, ast.ClassDef): + parent_nodes.append(('class', current_node.name)) + + if len(parent_nodes) > 1: # local method, cannot be called by other module + return ('local', None) + elif len(parent_nodes) < 1: + return ('outer', None) + elif parent_nodes[0][0] == 'func': # local method + return ('local', None) + elif parent_nodes[0][0] == 'class': # class method + return ('class', parent_nodes[0][1]) + else: + return ('outer', None) + + def visit_FunctionDef(self, node): + self.generic_visit(node) + node_type, class_name = self._get_func_type(node) + if node_type == 'local': + return + docstring = None + body_index = 0 + if node.body and isinstance(node.body[0], ast.Expr) and isinstance(node.body[0].value, ast.Constant): + docstring = node.body[0].value.value + body_index += 1 + func_node_start_end_positions = self._get_positon(node) + func_doc_start_end_positions = self._get_positon(node.body[0]) if docstring else None + func_body_start_end_positions = self._get_positon(node.body[body_index]) if len(node.body) > body_index else None + + if node_type == 'outer': + self.defined_outer_apis.append({ + 'api_name': node.name, + 'func_node_start_end_positions': func_node_start_end_positions, + 
'func_doc_start_end_positions': func_doc_start_end_positions, + 'func_body_start_end_positions': func_body_start_end_positions, + 'current_fpath_tuple': self.fpath_tuple + }) + elif node_type == 'class': + assert class_name + self.defined_classes[class_name].append({ + 'api_name': node.name if node.name != '__init__' else class_name, + 'class_name': class_name, + 'func_node_start_end_positions': func_node_start_end_positions, + 'func_doc_start_end_positions': func_doc_start_end_positions, + 'func_body_start_end_positions': func_body_start_end_positions, + 'current_fpath_tuple': self.fpath_tuple + }) diff --git a/RepoCoder/make_dataset/config.py b/RepoCoder/make_dataset/config.py new file mode 100644 index 0000000..4f4d165 --- /dev/null +++ b/RepoCoder/make_dataset/config.py @@ -0,0 +1,15 @@ +''' +from the setup.py, we can get the mapping from "package_dir" +''' +REPO_PACKAGE_DIR = { + 'huggingface_diffusers': (['src'], []), + 'nerfstudio-project_nerfstudio': (), + 'awslabs_fortuna': (), + 'huggingface_evaluate': (['src'], []), + 'google_vizier': (), + 'PaddlePaddle_PaddleTS': (), + 'microsoft_RegionCLIP': (), + 'alibaba_FederatedScope': (), + 'pytorch_rl': (), + 'opendilab_ACE': () +} \ No newline at end of file diff --git a/RepoCoder/make_dataset/file_visitors.py b/RepoCoder/make_dataset/file_visitors.py new file mode 100644 index 0000000..fcd00e5 --- /dev/null +++ b/RepoCoder/make_dataset/file_visitors.py @@ -0,0 +1,224 @@ +import ast +import ipdb + +from ast_visitors import APIDefineVisitor, APICallVisitor, APIImportVisitor +from config import REPO_PACKAGE_DIR + + +class FileCallAPI: + def __init__(self, repo, source_code_files): + self.repo = repo + self.source_code_files = source_code_files + self.api_calls_by_file = dict() + + def _ast_processor_call(self, code, fpath_tuple): + visitor = APICallVisitor(fpath_tuple) + visitor.visit(ast.parse(code)) + return visitor + + def get_called_apis_by_file(self): + print(f'Collecting called APIs in {self.repo}') + for fpath_tuple, code in self.source_code_files.items(): + try: + visitor = self._ast_processor_call(code, fpath_tuple) + except Exception as e: + print(f'{fpath_tuple} fail to parse: {e}') + continue + self.api_calls_by_file[fpath_tuple] = visitor.called_apis + return self.api_calls_by_file + + +class FileDefinedAPI: + def __init__(self, repo, source_code_files): + self.repo = repo + self.source_code_files = source_code_files + self.defined_apis_by_file = dict() + + def _ast_processor_define(self, code, fpath_tuple): + tree = ast.parse(code) + visitor = APIDefineVisitor(fpath_tuple) + visitor.store_parent_node(tree) + visitor.visit(tree) + return visitor + + def get_defined_apis_by_file(self): + ''' + find defined apis in each python file + ''' + print(f"Finding defined APIs in {self.repo}") + for fpath_tuple, code in self.source_code_files.items(): + try: + visitor = self._ast_processor_define(code, fpath_tuple) + except Exception as e: + print(f'{fpath_tuple} fail to parse: {e}') + continue + self.defined_apis_by_file[fpath_tuple] = { + 'defined_classes': visitor.defined_classes, # dict + 'defined_outer_apis': visitor.defined_outer_apis, # list + } + return self.defined_apis_by_file + +class FileImportedAPI: + def __init__(self, repo, source_code_files, defined_apis_by_file): + self.repo = repo + self.source_code_files = source_code_files + self.defined_apis_by_file = defined_apis_by_file + self.imported_apis_by_file = dict() + + def _ast_processor_import(self, code, file_tuple): + file_module = 
build_file_module_from_file_tuple(self.repo, file_tuple) + tree = ast.parse(code) + visitor = APIImportVisitor(file_module, file_tuple) + visitor.visit(tree) + return visitor + + def get_imported_apis_by_file(self): + ''' + find imported apis in each python file + ''' + print(f"Finding imported APIs in {self.repo}") + for fpath_tuple, code in self.source_code_files.items(): + try: + visitor = self._ast_processor_import(code, fpath_tuple) + except Exception as e: + print(f'{fpath_tuple} fail to parse: {e}') + continue + # tring to locate the module of the imported api and the type of api (class or outer func) + self.imported_apis_by_file[fpath_tuple] = self._get_apis_info(visitor.imported_apis) + return self.imported_apis_by_file + + def _get_apis_info(self, imported_apis): + ''' + tring to locate the module of the imported api and the type of api (class or outer func) + ''' + imported_classes = list() + imported_outer_apis = list() + # those that are hard to decide the type (an init file refers to an init that refers to an init...) + imported_members = list() + imported_modules = list() + for imported_api in imported_apis: + imported_api_info = self._map_imported_api_to_fpath_tuple(imported_api) + if not imported_api_info: # not an intra-project api + continue + imported_api_type = imported_api_info['type'] + located_module_path_tuple = imported_api_info['module_path_tuple'] + module_path = imported_api_info['module_path'] # foo.bar + if imported_api_type == 'module': + imported_modules.append({ + 'module_name': module_path, + 'located_module_path_tuple': located_module_path_tuple + }) + elif imported_api_type == 'member': + located_module_defined_apis = self.defined_apis_by_file[located_module_path_tuple] + api_name = imported_api_info['api_name'] + if api_name in located_module_defined_apis['defined_classes']: + imported_classes.append({ + 'class_name': api_name, + 'located_module_path_tuple': located_module_path_tuple + }) + elif api_name in located_module_defined_apis['defined_outer_apis']: + imported_outer_apis.append({ + 'api_name': api_name, + 'located_module_path_tuple': located_module_path_tuple + }) + else: # TODO: can not handle recursive import now + imported_members.append({ + 'member_name': api_name, + 'located_module_path_tuple': located_module_path_tuple + }) + return { + 'imported_classes': imported_classes, + 'imported_outer_apis': imported_outer_apis, + 'imported_modules': imported_modules, + 'imported_members': imported_members # not necessarily a callable api + } + + + def _map_imported_api_to_fpath_tuple(self, imported_api): + ''' + return the most possible file module for the imported api + ''' + def __find_possible_fpath_tuple(imported_node, current_fpath_tuple): + located_file_tuples = [ + fpath_tuple for fpath_tuple in self.defined_apis_by_file.keys() + if f'.{build_file_module_from_file_tuple(self.repo, fpath_tuple)}'.endswith(f'.{imported_node}') + ] + if len(located_file_tuples) == 1: + return located_file_tuples[0] + elif len(located_file_tuples) < 1: + return None + elif len(located_file_tuples) > 1: + # when multiple files are found, we need to find the most possible one + score = [self._longest_common_subsequence('.'.join(current_fpath_tuple), '.'.join(fpath_tuple)) for fpath_tuple in located_file_tuples] + max_score_index = score.index(max(score)) + if score.count(max(score)) > 1 and len(located_file_tuples) == 2 and len([i for i in located_file_tuples if i[-1] == '__init__.py']) != 0: + # choose the one without init when there are two files with the same 
score + return [i for i in located_file_tuples if i[-1] != '__init__.py'][0] + if score.count(max(score)) > 1: + print(located_file_tuples) + print(imported_api) + ipdb.set_trace() + return located_file_tuples[max_score_index] + + api_name = imported_api['api_name'] # imported api can be a member or a module + api_path = imported_api['api_path'] # foo.bar + current_fpath_tuple = imported_api['current_fpath_tuple'] + deepest_node = '.'.join([i for i in [api_path, api_name] if i]) + located_file_tuple = __find_possible_fpath_tuple(deepest_node, current_fpath_tuple) + if located_file_tuple: + # imported api is a module + return { + 'type': 'module', + 'module_path_tuple': located_file_tuple, + 'module_path': deepest_node + } + if not api_path: # imported api is not a intra-project module + return None + located_file_tuple = __find_possible_fpath_tuple(api_path, current_fpath_tuple) + if not located_file_tuple: + # imported api is not a intra-project module + return None + if api_name == '*': # import the entire module + return { + 'type': 'module', + 'module_path_tuple': located_file_tuple, + 'module_path': api_path + } + # imported api is a member and we can find the module + return { + 'type': 'member', + 'module_path_tuple': located_file_tuple, + 'module_path': api_path, + 'api_name': api_name, + } + + def _longest_common_subsequence(self, text1, text2): + shorter, longer = text1, text2 + if len(text2) < len(text1): + shorter, longer = text2, text1 + common_length = 0 + for i in range(len(shorter)): + common_length += 1 if shorter[i] == longer[i] else 0 + return common_length + + +def build_file_module_from_file_tuple(repo, fpath_tuple): + # fpath_tuple: (repo_name, 'webui', 'launch.py') + assert fpath_tuple[0] == repo and fpath_tuple[-1].endswith('.py') + fpath_tuple = list(fpath_tuple) + module_name = fpath_tuple[-1][:-3] # launch + fpath_repo_excluded = fpath_tuple[1:] # ('webui', 'launch.py') + if REPO_PACKAGE_DIR[repo]: # need to modify module path + original_package_dirs = REPO_PACKAGE_DIR[repo][0] + if not original_package_dirs[0] in fpath_repo_excluded: # the package_dir is not in the file path + module_list = fpath_repo_excluded[:-1] + [module_name] + else: # the package_dir is in the file path + assert all([fpath_repo_excluded[i] == original_package_dirs[i] for i in range(len(original_package_dirs))]) + mapped_source_code_dir = REPO_PACKAGE_DIR[repo][1] + fpath_repo_excluded[len(original_package_dirs):] + module_list = mapped_source_code_dir[:-1] + [module_name] # ['launch'] if REPO_PACKAGE_DIR[repo][1] is ('webui', []) + else: + module_list = fpath_repo_excluded[:-1] + [module_name] # ['webui', 'launch'] + + if module_name == '__init__': # (repo_name, 'webui', '__init__.py') + module_list = module_list[:-1] # ['webui'] + return '.'.join(module_list) # webui.launch \ No newline at end of file diff --git a/RepoCoder/make_dataset/make_dataset_utils.py b/RepoCoder/make_dataset/make_dataset_utils.py new file mode 100644 index 0000000..dccad49 --- /dev/null +++ b/RepoCoder/make_dataset/make_dataset_utils.py @@ -0,0 +1,103 @@ +import os +import glob +import ipdb +import pickle +import json +import tiktoken + + +class CodexTokenizer(): + def __init__(self): + self.tokenizer = tiktoken.get_encoding("p50k_base") + + def tokenize(self, text): + return self.tokenizer.encode_ordinary(text) + + def decode(self, token_ids): + return self.tokenizer.decode(token_ids) + +class Tools: + @staticmethod + def read_code(fname): + with open(fname, 'r', encoding='utf8') as f: + return f.read() + + 
@staticmethod + def load_pickle(fname): + with open(fname, 'rb') as f: + return pickle.load(f) + + @staticmethod + def load_json(fname): + with open(fname, 'r', encoding='utf8') as f: + return json.load(f) + + @staticmethod + def dump_pickle(obj, fname): + with open(fname, 'wb') as f: + pickle.dump(obj, f) + + @staticmethod + def dump_jsonl(obj, fname): + with open(fname, 'w', encoding='utf8') as f: + for item in obj: + f.write(json.dumps(item) + '\n') + + @staticmethod + def dump_json(obj, fname): + with open(fname, 'w', encoding='utf8') as f: + json.dump(obj, f, indent=4) + + @staticmethod + def load_jsonl(fname): + with open(fname, 'r', encoding='utf8') as f: + lines = [] + for line in f: + lines.append(json.loads(line)) + return lines + + @staticmethod + def trim_context(tokenizer, previous_context_lines, context_max_tokens): + previous_total_lines = len(previous_context_lines) + previous_context = '\n'.join(previous_context_lines) + tokens = tokenizer.tokenize(previous_context) + decoded_context_total_lines = tokenizer.decode(tokens).count('\n') + 1 + try: + assert previous_total_lines == decoded_context_total_lines + except AssertionError: + ipdb.set_trace() + trimmed_tokens = tokens[-context_max_tokens:] + trimmed_context = tokenizer.decode(trimmed_tokens) + trimed_context_total_lines = trimmed_context.count('\n') + 1 + trimed_context_start_lineno = previous_total_lines - trimed_context_total_lines # 0-indexed lineno + return trimmed_context.splitlines(), trimed_context_start_lineno + + @staticmethod + def iterate_repository(base_dir, repo): + pattern = os.path.join(f'{base_dir}/{repo}', "**", "*.py") + files = glob.glob(pattern, recursive=True) + + skipped_files = [] + loaded_code_files = dict() + base_dir_list = os.path.normpath(base_dir).split(os.sep) + for fname in files: + try: + code = Tools.read_code(fname) + fpath_tuple = tuple(os.path.normpath(fname).split(os.sep)[len(base_dir_list):]) + loaded_code_files[fpath_tuple]= code + except Exception as e: + skipped_files.append((fname, e)) + continue + + if len(skipped_files) > 0: + print(f"Skipped {len(skipped_files)} out of {len(files)} files due to I/O errors") + for fname, e in skipped_files: + print(f"{fname}: {e}") + return loaded_code_files + + @staticmethod + def tokenize(code): + tokenizer = CodexTokenizer() + return tokenizer.tokenize(code) + + diff --git a/RepoCoder/make_dataset/random_benchmark.py b/RepoCoder/make_dataset/random_benchmark.py new file mode 100644 index 0000000..cabd231 --- /dev/null +++ b/RepoCoder/make_dataset/random_benchmark.py @@ -0,0 +1,122 @@ +import os +import random + +from make_dataset_utils import Tools, CodexTokenizer + + +class RandomHoleDigger: + def __init__(self, repo_base_dir, repo, context_max_tokens=2000, line_min_tokens=5, max_sample_per_repo=200): + self.source_code_files = Tools.iterate_repository(repo_base_dir, repo) + self.context_max_tokens = context_max_tokens + self.max_sample_per_repo = max_sample_per_repo + self.line_min_tokens = line_min_tokens + self.repo = repo + self.tokenizer = CodexTokenizer() + + def _get_line_types(self, lines): + line_types = dict() + in_multiline_comment = False + multiline_comment_start = "" + for lineno, line in enumerate(lines): + stripped_line = line.strip() + if not stripped_line: + line_types[lineno] = 'empty' + continue + if in_multiline_comment: + if stripped_line.endswith(multiline_comment_start): + in_multiline_comment = False + line_types[lineno] = 'comment' + elif stripped_line.startswith('"""') or stripped_line.startswith("'''"): + 
in_multiline_comment = True + multiline_comment_start = stripped_line[:3] + line_types[lineno] = 'comment' + elif stripped_line and stripped_line[0] == "#": + line_types[lineno] = 'comment' + else: + line_types[lineno] = 'code' + return line_types + + def _get_usable_lines(self, lines): + line_types = self._get_line_types(lines) + usable_lines = [] + for lineno, line_type in line_types.items(): + if line_type == 'code': + if lineno == 0: + continue + if line_types[lineno - 1] == 'empty': + continue + usable_lines.append(lineno) + return usable_lines + + def get_chosen_lines(self): + candidate_lines = [] + for fpath_tuple, code in self.source_code_files.items(): + code_lines = code.splitlines() + usable_lines = self._get_usable_lines(code_lines) + candidate_lines.extend([(fpath_tuple, line_no) for line_no in usable_lines]) + random.shuffle(candidate_lines) + chosen_lines = [] + chosen_line_strs = set() + for fpath_tuple, line_no in candidate_lines: + code_lines = self.source_code_files[fpath_tuple].splitlines() + line = code_lines[line_no] + if len(self.tokenizer.tokenize(line)) > self.context_max_tokens: + continue + if line.strip() in chosen_line_strs: + continue + chosen_line_strs.add(line.strip()) + chosen_lines.append({ + 'fpath_tuple': fpath_tuple, + 'line_no': line_no, + 'ground_truth': line, + 'code_lines': code_lines + }) + if len(chosen_lines) >= self.max_sample_per_repo: + break + return chosen_lines + + def _make_context(self, line): + previous_lines = line['code_lines'][:line['line_no']] + trimmed_lines, trimed_context_start_lineno = Tools.trim_context(self.tokenizer, previous_lines, self.context_max_tokens) + return '\n'.join(trimmed_lines), trimed_context_start_lineno + + def make_dataset(self): + test_data = [] + chosen_lines = self.get_chosen_lines() + for index, line in enumerate(chosen_lines): + prompt, trimed_context_start_lineno = self._make_context(line) + test_data.append({ + 'prompt': prompt, + 'metadata': { + 'task_id': f'{self.repo}/{index}', + 'ground_truth': line['ground_truth'], + 'fpath_tuple': line['fpath_tuple'], + 'context_start_lineno': trimed_context_start_lineno, + 'line_no': line['line_no'] + }}) + print(f'Generated {len(test_data)} samples for {self.repo}') + return test_data + +if __name__ == '__main__': + OUT_BASE_DIR = 'output' + REPO_BASE_DIR = 'downloaded_repos' + repos = [ + 'huggingface_diffusers', + 'nerfstudio-project_nerfstudio', + 'awslabs_fortuna', + 'huggingface_evaluate', + 'google_vizier', + 'PaddlePaddle_PaddleTS', + 'microsoft_RegionCLIP', + 'alibaba_FederatedScope', + 'pytorch_rl', + 'opendilab_ACE' + ] + lines = [] + for repo in repos: + print(f'Processing {repo}') + digger = RandomHoleDigger(REPO_BASE_DIR, repo) + lines += digger.make_dataset() + Tools.dump_jsonl(lines, os.path.join(OUT_BASE_DIR, 'ten-repos-random-line-completion.jsonl')) + + From a91938083f404c3aaf5ea348c08f45aa3ed69090 Mon Sep 17 00:00:00 2001 From: "fengji.zhang" Date: Sun, 18 Aug 2024 16:46:40 +0800 Subject: [PATCH 6/8] fix file path bug --- RepoCoder/run_pipeline.py | 36 +++++++++++++++++++++--------------- RepoCoder/utils.py | 3 +++ 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/RepoCoder/run_pipeline.py b/RepoCoder/run_pipeline.py index e44e29e..4470140 100644 --- a/RepoCoder/run_pipeline.py +++ b/RepoCoder/run_pipeline.py @@ -2,6 +2,7 @@ # Licensed under the MIT license. 
import os +import itertools os.environ["TOKENIZERS_PARALLELISM"] = "false" from make_window import MakeWindowWrapper @@ -9,7 +10,7 @@ from search_code import CodeSearchWrapper from build_prompt import BuildPromptWrapper -from utils import CONSTANTS, CodexTokenizer +from utils import CONSTANTS, CodexTokenizer, CodeGenTokenizer def make_repo_window(repos, window_sizes, slice_sizes): MakeWindowWrapper(None, repos, window_sizes, slice_sizes).window_for_repo_files() @@ -26,14 +27,16 @@ def run_RG1_and_oracle_method(benchmark, repos, window_sizes, slice_sizes): # search code for vanilla retrieval-augmented approach and ground truth CodeSearchWrapper('one-gram', benchmark, repos, window_sizes, slice_sizes).search_baseline_and_ground() # build prompt for vanilla retrieval-augmented approach and ground truth - tokenizer = CodexTokenizer - mode = CONSTANTS.rg - output_file_path = 'prompts/rg-one-gram-ws-20-ss-2.jsonl' - BuildPromptWrapper('one-gram', benchmark, repos, window_sizes, slice_sizes, tokenizer).build_first_search_prompt(mode, output_file_path) + tokenizer = CodeGenTokenizer + + for window_size, slice_size in itertools.product(window_sizes, slice_sizes): + mode = CONSTANTS.rg + output_file_path = f'prompts/rg-one-gram-ws-{window_size}-ss-{slice_size}.jsonl' + BuildPromptWrapper('one-gram', benchmark, repos, window_size, slice_size, tokenizer).build_first_search_prompt(mode, output_file_path) - mode = CONSTANTS.gt - output_file_path = 'prompts/gt-one-gram-ws-20-ss-2.jsonl' - BuildPromptWrapper('one-gram', benchmark, repos, window_sizes, slice_sizes, tokenizer).build_first_search_prompt(mode, output_file_path) + mode = CONSTANTS.gt + output_file_path = f'prompts/gt-one-gram-ws-{window_size}-ss-{slice_size}.jsonl' + BuildPromptWrapper('one-gram', benchmark, repos, window_size, slice_size, tokenizer).build_first_search_prompt(mode, output_file_path) def run_RepoCoder_method(benchmark, repos, window_sizes, slice_sizes, prediction_path): @@ -42,9 +45,10 @@ def run_RepoCoder_method(benchmark, repos, window_sizes, slice_sizes, prediction vectorizer = BagOfWords BuildVectorWrapper(benchmark, vectorizer, repos, window_sizes, slice_sizes).vectorize_prediction_windows(mode, prediction_path) CodeSearchWrapper('one-gram', benchmark, repos, window_sizes, slice_sizes).search_prediction(mode, prediction_path) - tokenizer = CodexTokenizer - output_file_path = 'prompts/repocoder-one-gram-ws-20-ss-2.jsonl' - BuildPromptWrapper('one-gram', benchmark, repos, window_sizes, slice_sizes, tokenizer).build_prediction_prompt(mode, prediction_path, output_file_path) + tokenizer = CodeGenTokenizer + for window_size, slice_size in itertools.product(window_sizes, slice_sizes): + output_file_path = f'prompts/repocoder-one-gram-ws-{window_size}-ss-{slice_size}.jsonl' + BuildPromptWrapper('one-gram', benchmark, repos, window_size, slice_size, tokenizer).build_prediction_prompt(mode, prediction_path, output_file_path) if __name__ == '__main__': @@ -64,9 +68,11 @@ def run_RepoCoder_method(benchmark, repos, window_sizes, slice_sizes, prediction # build window for the repos make_repo_window(repos, window_sizes, slice_sizes) - # build prompt for the RG1 and oracle methods + # build prompt for the RG1 and oracle methods, after building the prompts, you should run inferece and then evaluate the results run_RG1_and_oracle_method(CONSTANTS.api_benchmark, repos, window_sizes, slice_sizes) - # build prompt for the RepoCoder method - prediction_path = 'predictions/rg-one-gram-ws-20-ss-2_samples.0.jsonl' - 
run_RepoCoder_method(CONSTANTS.api_benchmark, repos, window_sizes, slice_sizes, prediction_path) \ No newline at end of file + ''' + before building prompt for the RepoCoder method, you need to run inference on the prompts of RG1 method + ''' + # prediction_path = 'predictions/rg-one-gram-ws-20-ss-2_samples.0.jsonl' + # run_RepoCoder_method(CONSTANTS.api_benchmark, repos, window_sizes, slice_sizes, prediction_path) \ No newline at end of file diff --git a/RepoCoder/utils.py b/RepoCoder/utils.py index 6233a3d..f667115 100644 --- a/RepoCoder/utils.py +++ b/RepoCoder/utils.py @@ -116,16 +116,19 @@ def load_pickle(fname): @staticmethod def dump_pickle(obj, fname): + os.makedirs(os.path.dirname(fname), exist_ok=True) with open(fname, 'wb') as f: pickle.dump(obj, f) @staticmethod def dump_json(obj, fname): + os.makedirs(os.path.dirname(fname), exist_ok=True) with open(fname, 'w', encoding='utf8') as f: json.dump(obj, f) @staticmethod def dump_jsonl(obj, fname): + os.makedirs(os.path.dirname(fname), exist_ok=True) with open(fname, 'w', encoding='utf8') as f: for item in obj: f.write(json.dumps(item) + '\n') From 6263cdfacb4018767cba8892094fc1aa9617e3bd Mon Sep 17 00:00:00 2001 From: "fengji.zhang" Date: Sun, 18 Aug 2024 17:14:33 +0800 Subject: [PATCH 7/8] change to codex default setting --- RepoCoder/run_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/RepoCoder/run_pipeline.py b/RepoCoder/run_pipeline.py index 4470140..002d73c 100644 --- a/RepoCoder/run_pipeline.py +++ b/RepoCoder/run_pipeline.py @@ -69,10 +69,10 @@ def run_RepoCoder_method(benchmark, repos, window_sizes, slice_sizes, prediction make_repo_window(repos, window_sizes, slice_sizes) # build prompt for the RG1 and oracle methods, after building the prompts, you should run inferece and then evaluate the results - run_RG1_and_oracle_method(CONSTANTS.api_benchmark, repos, window_sizes, slice_sizes) + run_RG1_and_oracle_method(CONSTANTS.short_api_benchmark, repos, window_sizes, slice_sizes) ''' before building prompt for the RepoCoder method, you need to run inference on the prompts of RG1 method ''' # prediction_path = 'predictions/rg-one-gram-ws-20-ss-2_samples.0.jsonl' - # run_RepoCoder_method(CONSTANTS.api_benchmark, repos, window_sizes, slice_sizes, prediction_path) \ No newline at end of file + # run_RepoCoder_method(CONSTANTS.short_api_benchmark, repos, window_sizes, slice_sizes, prediction_path) \ No newline at end of file From d1b55f698803b71462a47af76f40a220060b01e1 Mon Sep 17 00:00:00 2001 From: Fengji Zhang Date: Fri, 1 Nov 2024 21:03:26 +0800 Subject: [PATCH 8/8] fix retrieval bug --- RepoCoder/build_prompt.py | 57 ++++++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 24 deletions(-) diff --git a/RepoCoder/build_prompt.py b/RepoCoder/build_prompt.py index 585a31d..ac3168a 100644 --- a/RepoCoder/build_prompt.py +++ b/RepoCoder/build_prompt.py @@ -39,31 +39,37 @@ def _make_a_block(self, retrieved_context): token_len = len(tokenized_block) return block_str, token_len - def _make_an_extended_block(self, retrieved_context): + def _make_an_extended_block(self, task_metadata, retrieved_context): content, sim_score = retrieved_context metadata = content['metadata'] - # put the file path in the comment - assert metadata[0]['fpath_tuple'][0] == metadata[0]['repo'] - f_paths = ['/'.join(x['fpath_tuple'][1:]) for x in metadata] - f_paths_str = '\n'.join([f'# {f_path}' for f_path in f_paths]) - f_path_comment = f'# the below code fragment can be found in:' - # put code lines in the 
comment - original_code = Tools.read_code(os.path.join(FilePathBuilder.repo_base_dir, *metadata[0]['fpath_tuple'])) - code_lines = original_code.splitlines() - end_line_no = metadata[0]['end_line_no'] - window_size = metadata[0]['window_size'] - slice_size = metadata[0]['slice_size'] - new_end_line_no = min(end_line_no + window_size // slice_size, len(code_lines)) - new_start_line_no = max(0, new_end_line_no - window_size) - content_lines = code_lines[new_start_line_no:new_end_line_no] - content_lines_comment = [f'# {line}' for line in content_lines] - # aggregate the comment and the code lines - block_str = '\n'.join([f_path_comment, f_paths_str, self.seperator] + content_lines_comment + [self.seperator]) + '\n' - tokenized_block = self.tokenizer.tokenize(block_str) - token_len = len(tokenized_block) - return block_str, token_len + duplicate_num = len(metadata) # for those share the exact same code fragment from different files + for i in range(duplicate_num): + # put the file path in the comment + assert metadata[i]['fpath_tuple'][0] == metadata[i]['repo'] + f_paths = ['/'.join(x['fpath_tuple'][1:]) for x in metadata] + f_paths_str = '\n'.join([f'# {f_path}' for f_path in f_paths]) + f_path_comment = f'# the below code fragment can be found in:' + # put code lines in the comment + original_code = Tools.read_code(os.path.join(FilePathBuilder.repo_base_dir, *metadata[i]['fpath_tuple'])) + code_lines = original_code.splitlines() + end_line_no = metadata[i]['end_line_no'] + window_size = metadata[i]['window_size'] + slice_size = metadata[i]['slice_size'] + new_end_line_no = min(end_line_no + window_size // slice_size, len(code_lines)) + new_start_line_no = max(0, new_end_line_no - window_size) + if metadata[i]['fpath_tuple'] == tuple(task_metadata['fpath_tuple']) and new_end_line_no >= task_metadata['line_no']: + continue + content_lines = code_lines[new_start_line_no:new_end_line_no] + content_lines_comment = [f'# {line}' for line in content_lines] + # aggregate the comment and the code lines + block_str = '\n'.join([f_path_comment, f_paths_str, self.seperator] + content_lines_comment + [self.seperator]) + '\n' + tokenized_block = self.tokenizer.tokenize(block_str) + token_len = len(tokenized_block) + return block_str, token_len + else: + return '', 0 - def _build_prompt(self, mode, prompt, top_k_context): + def _build_prompt(self, mode, prompt, task_metadata, top_k_context): prepend_context = "# Here are some relevant code fragments from other files of the repo:\n" prepend_context += self.seperator + '\n' current_token_length = 20 # the length of the head_prompt, same for codex and codegen tokenizer @@ -73,7 +79,10 @@ def _build_prompt(self, mode, prompt, top_k_context): for retrieved_context in top_k_context[::-1]: if len(chosen_context) >= self.max_examples: break - block_str, token_len = make_block_func(retrieved_context) + kwargs = {'retrieved_context': retrieved_context} + if mode == CONSTANTS.rg: + kwargs['task_metadata'] = task_metadata + block_str, token_len = make_block_func(**kwargs) if current_token_length + token_len < self.max_retrieval_length: prepend_blocks.insert(0, block_str) current_token_length += token_len @@ -90,7 +99,7 @@ def build_2nd_stage_input_file(self, mode): task = self.tasks_by_task_id[task_id] old_prompt = task['prompt'] top_k_context = query_line['top_k_context'] - new_prompt, chosen_context = self._build_prompt(mode, old_prompt, top_k_context) + new_prompt, chosen_context = self._build_prompt(mode, old_prompt, task['metadata'], top_k_context) 
new_prompt_line = { 'prompt': new_prompt, 'metadata': task['metadata'],
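Taken together, the patches above change the intended workflow: first build the repository windows and the RG1/oracle prompts with run_pipeline.py, then run local CodeGen inference with codegen_inference.py, and finally feed the generations back in for the RepoCoder second retrieval stage. The sketch below is illustrative only and is not part of any patch in this series; it assumes it is run from the RepoCoder/ directory with datasets/ and repositories/ unpacked as the README describes, that 'huggingface_diffusers' stands in for whichever benchmark repositories are actually configured, and that the JSONL file written by codegen_inference.py can be passed directly as prediction_path.

# end_to_end_sketch.py -- hypothetical driver script, not part of the patch series
from utils import CONSTANTS
from run_pipeline import make_repo_window, run_RG1_and_oracle_method, run_RepoCoder_method
from codegen_inference import CodeGen

# Assumption: these repos exist under repositories/line_and_api_level
repos = ['huggingface_diffusers']
window_sizes, slice_sizes = [20], [2]

# Stage 1: build code windows/vectors for the repos, then the RG1 and oracle prompts
make_repo_window(repos, window_sizes, slice_sizes)
run_RG1_and_oracle_method(CONSTANTS.short_api_benchmark, repos, window_sizes, slice_sizes)

# Generate completions for the RG1 prompts with a small CodeGen model
cg = CodeGen('Salesforce/codegen-350M-mono', batch_size=8)
cg.batch_generate('prompts/rg-one-gram-ws-20-ss-2.jsonl')  # prompt file produced above for ws=20, ss=2

# Stage 2: use the generations as the retrieval query for the RepoCoder round
prediction_path = 'prompts/rg-one-gram-ws-20-ss-2_codegen-350M-mono.jsonl'  # written by batch_generate
run_RepoCoder_method(CONSTANTS.short_api_benchmark, repos, window_sizes, slice_sizes, prediction_path)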