Skip to content

Commit 1b25caf

Browse files
committed
Address second iteration of comments.
1 parent 0216007 commit 1b25caf

File tree

6 files changed: +33 additions, −30 deletions

gcp_variant_transforms/options/variant_transform_options.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,8 @@ def add_arguments(self, parser):
185185
parser.add_argument(
186186
'--num_bigquery_write_shards',
187187
type=int, default=1,
188-
help=('This flag is deprecated and may be removed in future releases.'))
188+
help=('This flag is deprecated and will be removed in future '
189+
'releases.'))
189190
parser.add_argument(
190191
'--null_numeric_value_replacement',
191192
type=int,

gcp_variant_transforms/pipeline_common.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@ def parse_args(argv, command_line_options):
7373
known_args, pipeline_args = parser.parse_known_args(argv)
7474
for transform_options in options:
7575
transform_options.validate(known_args)
76-
_raise_error_on_invalid_flags(pipeline_args)
76+
_raise_error_on_invalid_flags(
77+
pipeline_args,
78+
known_args.output_table if hasattr(known_args, 'output_table') else None)
7779
if hasattr(known_args, 'input_pattern') or hasattr(known_args, 'input_file'):
7880
known_args.all_patterns = _get_all_patterns(
7981
known_args.input_pattern, known_args.input_file)
@@ -304,8 +306,8 @@ def write_headers(merged_header, file_path):
304306
vcf_header_io.WriteVcfHeaders(file_path))
305307

306308

307-
def _raise_error_on_invalid_flags(pipeline_args):
308-
# type: (List[str]) -> None
309+
def _raise_error_on_invalid_flags(pipeline_args, output_table):
310+
# type: (List[str], Any) -> None
309311
"""Raises an error if there are unrecognized flags."""
310312
parser = argparse.ArgumentParser()
311313
for cls in pipeline_options.PipelineOptions.__subclasses__():
@@ -318,6 +320,14 @@ def _raise_error_on_invalid_flags(pipeline_args):
318320
not known_pipeline_args.setup_file):
319321
raise ValueError('The --setup_file flag is required for DataflowRunner. '
320322
'Please provide a path to the setup.py file.')
323+
if output_table:
324+
if (not hasattr(known_pipeline_args, 'temp_location') or
325+
not known_pipeline_args.temp_location):
326+
raise ValueError('--temp_location is required for BigQuery imports.')
327+
if not known_pipeline_args.temp_location.startswith('gs://'):
328+
raise ValueError(
329+
'--temp_location must be valid GCS location for BigQuery imports')
330+
321331

322332

323333
def is_pipeline_direct_runner(pipeline):

gcp_variant_transforms/pipeline_common_test.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,21 +94,31 @@ def test_fail_on_invalid_flags(self):
9494
'gcp-variant-transforms-test',
9595
'--staging_location',
9696
'gs://integration_test_runs/staging']
97-
pipeline_common._raise_error_on_invalid_flags(pipeline_args)
97+
pipeline_common._raise_error_on_invalid_flags(pipeline_args, None)
9898

9999
# Add Dataflow runner (requires --setup_file).
100100
pipeline_args.extend(['--runner', 'DataflowRunner'])
101101
with self.assertRaisesRegexp(ValueError, 'setup_file'):
102-
pipeline_common._raise_error_on_invalid_flags(pipeline_args)
102+
pipeline_common._raise_error_on_invalid_flags(pipeline_args, None)
103103

104104
# Add setup.py (required for Variant Transforms run). This is now valid.
105105
pipeline_args.extend(['--setup_file', 'setup.py'])
106-
pipeline_common._raise_error_on_invalid_flags(pipeline_args)
106+
pipeline_common._raise_error_on_invalid_flags(pipeline_args, None)
107+
108+
with self.assertRaisesRegexp(ValueError, '--temp_location is required*'):
109+
pipeline_common._raise_error_on_invalid_flags(pipeline_args, 'output')
110+
111+
pipeline_args.extend(['--temp_location', 'wrong_gcs'])
112+
with self.assertRaisesRegexp(ValueError, '--temp_location must be valid*'):
113+
pipeline_common._raise_error_on_invalid_flags(pipeline_args, 'output')
114+
115+
pipeline_args = pipeline_args[:-1] + ['gs://valid_bucket/temp']
116+
pipeline_common._raise_error_on_invalid_flags(pipeline_args, 'output')
107117

108118
# Add an unknown flag.
109119
pipeline_args.extend(['--unknown_flag', 'somevalue'])
110120
with self.assertRaisesRegexp(ValueError, 'Unrecognized.*unknown_flag'):
111-
pipeline_common._raise_error_on_invalid_flags(pipeline_args)
121+
pipeline_common._raise_error_on_invalid_flags(pipeline_args, 'output')
112122

113123
def test_get_compression_type(self):
114124
vcf_metadata_list = [filesystem.FileMetadata(path, size) for

gcp_variant_transforms/transforms/sample_info_to_bigquery.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ def __init__(self, output_table_prefix, sample_name_encoding, append=False):
6969
self._append = append
7070
self._sample_name_encoding = sample_name_encoding
7171
self._schema = sample_info_table_schema_generator.generate_schema()
72-
self._temp_location = temp_location
7372

7473
def expand(self, pcoll):
7574
return (pcoll
@@ -84,5 +83,4 @@ def expand(self, pcoll):
8483
beam.io.BigQueryDisposition.WRITE_APPEND
8584
if self._append
8685
else beam.io.BigQueryDisposition.WRITE_TRUNCATE),
87-
method=beam.io.WriteToBigQuery.Method.FILE_LOADS,
88-
custom_gcs_temp_location=self._temp_location))
86+
method=beam.io.WriteToBigQuery.Method.FILE_LOADS))

gcp_variant_transforms/transforms/variant_to_bigquery.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,5 +119,4 @@ def expand(self, pcoll):
119119
beam.io.BigQueryDisposition.WRITE_APPEND
120120
if self._append
121121
else beam.io.BigQueryDisposition.WRITE_TRUNCATE),
122-
method=beam.io.WriteToBigQuery.Method.FILE_LOADS,
123-
custom_gcs_temp_location=self._temp_location))
122+
method=beam.io.WriteToBigQuery.Method.FILE_LOADS))

gcp_variant_transforms/vcf_to_bq.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,6 @@ def _run_annotation_pipeline(known_args, pipeline_args):
389389
def _create_sample_info_table(pipeline, # type: beam.Pipeline
390390
pipeline_mode, # type: PipelineModes
391391
known_args, # type: argparse.Namespace,
392-
temp_directory, # str
393392
):
394393
# type: (...) -> None
395394
headers = pipeline_common.read_headers(
@@ -399,14 +398,8 @@ def _create_sample_info_table(pipeline, # type: beam.Pipeline
399398
_ = (headers | 'SampleInfoToBigQuery' >>
400399
sample_info_to_bigquery.SampleInfoToBigQuery(
401400
known_args.output_table,
402-
<<<<<<< HEAD
403401
SampleNameEncoding[known_args.sample_name_encoding],
404402
known_args.append))
405-
=======
406-
temp_directory,
407-
known_args.append,
408-
known_args.samples_span_multiple_files))
409-
>>>>>>> Address first iteration of comments.
410403

411404

412405
def run(argv=None):
@@ -415,8 +408,6 @@ def run(argv=None):
415408
logging.info('Command: %s', ' '.join(argv or sys.argv))
416409
known_args, pipeline_args = pipeline_common.parse_args(argv,
417410
_COMMAND_LINE_OPTIONS)
418-
if known_args.output_table and '--temp_location' not in pipeline_args:
419-
raise ValueError('--temp_location is required for BigQuery imports.')
420411
if known_args.auto_flags_experiment:
421412
_get_input_dimensions(known_args, pipeline_args)
422413

@@ -492,6 +483,7 @@ def run(argv=None):
492483
num_shards = 1
493484

494485
if known_args.output_table:
486+
<<<<<<< HEAD
495487
<<<<<<< HEAD
496488
schema_file = tempfile.mkstemp(prefix=known_args.output_table,
497489
suffix=_BQ_SCHEMA_FILE_SUFFIX)[1]
@@ -504,13 +496,6 @@ def run(argv=None):
504496
file_to_write.write(schema_json)
505497

506498
for i in range(num_shards):
507-
=======
508-
temp_directory = pipeline_options.PipelineOptions(pipeline_args).view_as(
509-
pipeline_options.GoogleCloudOptions).temp_location
510-
if not temp_directory:
511-
raise ValueError('--temp_location must be set when writing to BigQuery.')
512-
for i in range(num_partitions):
513-
>>>>>>> Address first iteration of comments.
514499
table_suffix = ''
515500
if sharding and sharding.get_shard_name(i):
516501
table_suffix = '_' + sharding.get_shard_name(i)
@@ -527,7 +512,7 @@ def run(argv=None):
527512
known_args.null_numeric_value_replacement)))
528513
if known_args.generate_sample_info_table:
529514
_create_sample_info_table(
530-
pipeline, pipeline_mode, known_args, temp_directory)
515+
pipeline, pipeline_mode, known_args)
531516

532517
if known_args.output_avro_path:
533518
# TODO(bashir2): Add an integration test that outputs to Avro files and

Comments (0)