From 6e623e79939431a1b723a860623c3578178f4348 Mon Sep 17 00:00:00 2001 From: Ali Date: Tue, 22 Jul 2025 10:51:52 +0400 Subject: [PATCH 1/2] Added option to generate a custom error for stringNameToValueMethod --- README.md | 4 +- enumer.go | 28 ++++++++- golden_test.go | 35 ++++++----- stringer.go | 10 +++- testdata/dayWithCustomError.golden | 93 ++++++++++++++++++++++++++++++ 5 files changed, 149 insertions(+), 21 deletions(-) create mode 100644 testdata/dayWithCustomError.golden diff --git a/README.md b/README.md index 084fcd8..4d0b338 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ Enumer is a tool to generate Go code that adds useful methods to Go enums (constants with a specific type). It started as a fork of [Rob Pike’s Stringer tool](https://godoc.org/golang.org/x/tools/cmd/stringer) -maintained by [Álvaro López Espinosa](https://github.com/alvaroloes/enumer). +maintained by [Álvaro López Espinosa](https://github.com/alvaroloes/enumer). This was again forked here as (https://github.com/dmarkham/enumer) picking up where Álvaro left off. @@ -41,6 +41,8 @@ Flags: if true, alternative string values method will be generated. Default: false -yaml if true, yaml marshaling methods will be generated. Default: false + -customerror + if true, a custom error type will be generated. Default: false ``` diff --git a/enumer.go b/enumer.go index c234209..5b71f7a 100644 --- a/enumer.go +++ b/enumer.go @@ -2,7 +2,12 @@ package main import "fmt" +const customInvalidError = `// ErrInvalid%[1]s is a custom error type for %[1]s +var ErrInvalid%[1]s = errors.New("invalid value for %[1]s") +` + // Arguments to format are: +// // [1]: type name const stringNameToValueMethod = `// %[1]sString retrieves an enum value from the enum constants string name. // Throws an error if the param is not part of the enum. @@ -14,11 +19,12 @@ func %[1]sString(s string) (%[1]s, error) { if val, ok := _%[1]sNameToValueMap[strings.ToLower(s)]; ok { return val, nil } - return 0, fmt.Errorf("%%s does not belong to %[1]s values", s) + return 0, %[2]s } ` // Arguments to format are: +// // [1]: type name const stringValuesMethod = `// %[1]sValues returns all values of the enum func %[1]sValues() []%[1]s { @@ -27,6 +33,7 @@ func %[1]sValues() []%[1]s { ` // Arguments to format are: +// // [1]: type name const stringsMethod = `// %[1]sStrings returns a slice of all String values of the enum func %[1]sStrings() []string { @@ -37,6 +44,7 @@ func %[1]sStrings() []string { ` // Arguments to format are: +// // [1]: type name const stringBelongsMethodLoop = `// IsA%[1]s returns "true" if the value is listed in the enum definition. "false" otherwise func (i %[1]s) IsA%[1]s() bool { @@ -50,6 +58,7 @@ func (i %[1]s) IsA%[1]s() bool { ` // Arguments to format are: +// // [1]: type name const stringBelongsMethodSet = `// IsA%[1]s returns "true" if the value is listed in the enum definition. "false" otherwise func (i %[1]s) IsA%[1]s() bool { @@ -59,6 +68,7 @@ func (i %[1]s) IsA%[1]s() bool { ` // Arguments to format are: +// // [1]: type name const altStringValuesMethod = `func (%[1]s) Values() []string { return %[1]sStrings() @@ -70,7 +80,7 @@ func (g *Generator) buildAltStringValuesMethod(typeName string) { g.Printf(altStringValuesMethod, typeName) } -func (g *Generator) buildBasicExtras(runs [][]Value, typeName string, runsThreshold int) { +func (g *Generator) buildBasicExtras(runs [][]Value, typeName string, runsThreshold int, customError bool) { // At this moment, either "g.declareIndexAndNameVars()" or "g.declareNameVars()" has been called // Print the slice of values @@ -89,7 +99,16 @@ func (g *Generator) buildBasicExtras(runs [][]Value, typeName string, runsThresh g.printNamesSlice(runs, typeName, runsThreshold) // Print the basic extra methods - g.Printf(stringNameToValueMethod, typeName) + if customError { + g.Printf(customInvalidError, typeName) + } + + stringNameToValueErr := `fmt.Errorf("%%s does not belong to %s values", s)` + if customError { + stringNameToValueErr = `ErrInvalid%s` + } + g.Printf(stringNameToValueMethod, typeName, fmt.Sprintf(stringNameToValueErr, typeName)) + g.Printf(stringValuesMethod, typeName) g.Printf(stringsMethod, typeName) if len(runs) <= runsThreshold { @@ -144,6 +163,7 @@ func (g *Generator) printNamesSlice(runs [][]Value, typeName string, runsThresho } // Arguments to format are: +// // [1]: type name const jsonMethods = ` // MarshalJSON implements the json.Marshaler interface for %[1]s @@ -169,6 +189,7 @@ func (g *Generator) buildJSONMethods(runs [][]Value, typeName string, runsThresh } // Arguments to format are: +// // [1]: type name const textMethods = ` // MarshalText implements the encoding.TextMarshaler interface for %[1]s @@ -189,6 +210,7 @@ func (g *Generator) buildTextMethods(runs [][]Value, typeName string, runsThresh } // Arguments to format are: +// // [1]: type name const yamlMethods = ` // MarshalYAML implements a YAML Marshaler for %[1]s diff --git a/golden_test.go b/golden_test.go index 03479d2..870204e 100644 --- a/golden_test.go +++ b/golden_test.go @@ -68,6 +68,10 @@ var goldenWithPrefix = []Golden{ {"dayWithPrefix", dayIn}, } +var goldenWithCustomError = []Golden{ + {"dayWithCustomError", dayIn}, +} + var goldenTrimAndAddPrefix = []Golden{ {"dayTrimAndPrefix", trimPrefixIn}, } @@ -315,45 +319,48 @@ const ( func TestGolden(t *testing.T) { for _, test := range golden { - runGoldenTest(t, test, false, false, false, false, false, false, true, "", "") + runGoldenTest(t, test, false, false, false, false, false, false, true, false, "", "") } for _, test := range goldenJSON { - runGoldenTest(t, test, true, false, false, false, false, false, false, "", "") + runGoldenTest(t, test, true, false, false, false, false, false, false, false, "", "") } for _, test := range goldenText { - runGoldenTest(t, test, false, false, false, true, false, false, false, "", "") + runGoldenTest(t, test, false, false, false, true, false, false, false, false, "", "") } for _, test := range goldenYAML { - runGoldenTest(t, test, false, true, false, false, false, false, false, "", "") + runGoldenTest(t, test, false, true, false, false, false, false, false, false, "", "") } for _, test := range goldenSQL { - runGoldenTest(t, test, false, false, true, false, false, false, false, "", "") + runGoldenTest(t, test, false, false, true, false, false, false, false, false, "", "") } for _, test := range goldenJSONAndSQL { - runGoldenTest(t, test, true, false, true, false, false, false, false, "", "") + runGoldenTest(t, test, true, false, true, false, false, false, false, false, "", "") } for _, test := range goldenGQLGen { - runGoldenTest(t, test, false, false, false, false, false, true, false, "", "") + runGoldenTest(t, test, false, false, false, false, false, true, false, false, "", "") } for _, test := range goldenTrimPrefix { - runGoldenTest(t, test, false, false, false, false, false, false, false, "Day", "") + runGoldenTest(t, test, false, false, false, false, false, false, false, false, "Day", "") } for _, test := range goldenTrimPrefixMultiple { - runGoldenTest(t, test, false, false, false, false, false, false, false, "Day,Night", "") + runGoldenTest(t, test, false, false, false, false, false, false, false, false, "Day,Night", "") } for _, test := range goldenWithPrefix { - runGoldenTest(t, test, false, false, false, false, false, false, false, "", "Day") + runGoldenTest(t, test, false, false, false, false, false, false, false, false, "", "Day") + } + for _, test := range goldenWithCustomError { + runGoldenTest(t, test, false, false, false, false, false, false, false, true, "", "Day") } for _, test := range goldenTrimAndAddPrefix { - runGoldenTest(t, test, false, false, false, false, false, false, false, "Day", "Night") + runGoldenTest(t, test, false, false, false, false, false, false, false, false, "Day", "Night") } for _, test := range goldenLinecomment { - runGoldenTest(t, test, false, false, false, false, true, false, false, "", "") + runGoldenTest(t, test, false, false, false, false, true, false, false, false, "", "") } } func runGoldenTest(t *testing.T, test Golden, - generateJSON, generateYAML, generateSQL, generateText, linecomment, generateGQLGen, generateValuesMethod bool, + generateJSON, generateYAML, generateSQL, generateText, linecomment, generateGQLGen, generateValuesMethod bool, generateCustomError bool, trimPrefix string, prefix string) { var g Generator @@ -382,7 +389,7 @@ func runGoldenTest(t *testing.T, test Golden, if len(tokens) != 3 { t.Fatalf("%s: need type declaration on first line", test.name) } - g.generate(tokens[1], generateJSON, generateYAML, generateSQL, generateText, generateGQLGen, "noop", trimPrefix, prefix, linecomment, generateValuesMethod) + g.generate(tokens[1], generateJSON, generateYAML, generateSQL, generateText, generateGQLGen, "noop", trimPrefix, prefix, linecomment, generateCustomError, generateValuesMethod) got := string(g.format()) if got != loadGolden(test.name) { // Use this to help build a golden text when changes are needed diff --git a/stringer.go b/stringer.go index 33d3d09..57cf793 100644 --- a/stringer.go +++ b/stringer.go @@ -56,6 +56,7 @@ var ( trimPrefix = flag.String("trimprefix", "", "transform each item name by removing a prefix or comma separated list of prefixes. Default: \"\"") addPrefix = flag.String("addprefix", "", "transform each item name by adding a prefix. Default: \"\"") linecomment = flag.Bool("linecomment", false, "use line comment text as printed text when present") + customError = flag.Bool("customerror", false, "if true, add a custom error type. Default: false") ) var comments arrayFlags @@ -131,11 +132,14 @@ func main() { g.Printf("\t\"io\"\n") g.Printf("\t\"strconv\"\n") } + if *customError { + g.Printf("\t\"errors\"\n") + } g.Printf(")\n") // Run generate for each type. for _, typeName := range typs { - g.generate(typeName, *json, *yaml, *sql, *text, *gqlgen, *transformMethod, *trimPrefix, *addPrefix, *linecomment, *altValuesFunc) + g.generate(typeName, *json, *yaml, *sql, *text, *gqlgen, *transformMethod, *trimPrefix, *addPrefix, *linecomment, *customError, *altValuesFunc) } // Format the output. @@ -415,7 +419,7 @@ func (g *Generator) prefixValueNames(values []Value, prefix string) { // generate produces the String method for the named type. func (g *Generator) generate(typeName string, includeJSON, includeYAML, includeSQL, includeText, includeGQLGen bool, - transformMethod string, trimPrefix string, addPrefix string, lineComment bool, includeValuesMethod bool) { + transformMethod string, trimPrefix string, addPrefix string, lineComment bool, customError bool, includeValuesMethod bool) { values := make([]Value, 0, 100) for _, file := range g.pkg.files { file.lineComment = lineComment @@ -468,7 +472,7 @@ func (g *Generator) generate(typeName string, g.buildNoOpOrderChangeDetect(runs, typeName) - g.buildBasicExtras(runs, typeName, runsThreshold) + g.buildBasicExtras(runs, typeName, runsThreshold, customError) if includeJSON { g.buildJSONMethods(runs, typeName, runsThreshold) } diff --git a/testdata/dayWithCustomError.golden b/testdata/dayWithCustomError.golden new file mode 100644 index 0000000..5b964d3 --- /dev/null +++ b/testdata/dayWithCustomError.golden @@ -0,0 +1,93 @@ + +const _DayName = "DayMondayDayTuesdayDayWednesdayDayThursdayDayFridayDaySaturdayDaySunday" + +var _DayIndex = [...]uint8{0, 9, 19, 31, 42, 51, 62, 71} + +const _DayLowerName = "daymondaydaytuesdaydaywednesdaydaythursdaydayfridaydaysaturdaydaysunday" + +func (i Day) String() string { + if i < 0 || i >= Day(len(_DayIndex)-1) { + return fmt.Sprintf("Day(%d)", i) + } + return _DayName[_DayIndex[i]:_DayIndex[i+1]] +} + +// An "invalid array index" compiler error signifies that the constant values have changed. +// Re-run the stringer command to generate them again. +func _DayNoOp() { + var x [1]struct{} + _ = x[Monday-(0)] + _ = x[Tuesday-(1)] + _ = x[Wednesday-(2)] + _ = x[Thursday-(3)] + _ = x[Friday-(4)] + _ = x[Saturday-(5)] + _ = x[Sunday-(6)] +} + +var _DayValues = []Day{Monday, Tuesday, Wednesday, Thursday, Friday, Saturday, Sunday} + +var _DayNameToValueMap = map[string]Day{ + _DayName[0:9]: Monday, + _DayLowerName[0:9]: Monday, + _DayName[9:19]: Tuesday, + _DayLowerName[9:19]: Tuesday, + _DayName[19:31]: Wednesday, + _DayLowerName[19:31]: Wednesday, + _DayName[31:42]: Thursday, + _DayLowerName[31:42]: Thursday, + _DayName[42:51]: Friday, + _DayLowerName[42:51]: Friday, + _DayName[51:62]: Saturday, + _DayLowerName[51:62]: Saturday, + _DayName[62:71]: Sunday, + _DayLowerName[62:71]: Sunday, +} + +var _DayNames = []string{ + _DayName[0:9], + _DayName[9:19], + _DayName[19:31], + _DayName[31:42], + _DayName[42:51], + _DayName[51:62], + _DayName[62:71], +} + +// ErrInvalidDay is a custom error type for Day +var ErrInvalidDay = errors.New("invalid value for Day") + +// DayString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func DayString(s string) (Day, error) { + if val, ok := _DayNameToValueMap[s]; ok { + return val, nil + } + + if val, ok := _DayNameToValueMap[strings.ToLower(s)]; ok { + return val, nil + } + return 0, ErrInvalidDay +} + +// DayValues returns all values of the enum +func DayValues() []Day { + return _DayValues +} + +// DayStrings returns a slice of all String values of the enum +func DayStrings() []string { + strs := make([]string, len(_DayNames)) + copy(strs, _DayNames) + return strs +} + +// IsADay returns "true" if the value is listed in the enum definition. "false" otherwise +func (i Day) IsADay() bool { + for _, v := range _DayValues { + if i == v { + return true + } + } + return false +} From b69b324ff5b1e5c5620512c539e590da8b3148ad Mon Sep 17 00:00:00 2001 From: Ali Date: Fri, 1 Aug 2025 15:57:36 +0400 Subject: [PATCH 2/2] Use package-provided custom error type instead of generating custom error types - added a new public package containing a custom error type: InvalidEnumValueError - modified the code to return this new custom error type when -customerror flag is used - removed usage of deprecated ioutil library --- .gitignore | 2 +- README.md | 2 +- endtoend_test.go | 4 ++-- enumer.go | 18 ++++++------------ golden_test.go | 14 ++++---------- pkg/enumer/enumer.go | 5 +++++ stringer.go | 14 ++++++-------- testdata/dayWithCustomError.golden | 5 +---- 8 files changed, 26 insertions(+), 38 deletions(-) create mode 100644 pkg/enumer/enumer.go diff --git a/.gitignore b/.gitignore index 44f0924..7b44ef1 100644 --- a/.gitignore +++ b/.gitignore @@ -28,4 +28,4 @@ coverage.txt .idea .vscode -enumer +./enumer diff --git a/README.md b/README.md index 4d0b338..3b49a58 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ Flags: -yaml if true, yaml marshaling methods will be generated. Default: false -customerror - if true, a custom error type will be generated. Default: false + if true, a custom error will be returned by the `String` function. Default: false ``` diff --git a/endtoend_test.go b/endtoend_test.go index 203ecd6..fa65968 100644 --- a/endtoend_test.go +++ b/endtoend_test.go @@ -4,6 +4,7 @@ // go command is not available on android +//go:build !android // +build !android package main @@ -12,7 +13,6 @@ import ( "fmt" "go/build" "io" - "io/ioutil" "os" "os/exec" "path/filepath" @@ -40,7 +40,7 @@ func init() { // binary panics if the String method for X is not correct, including for error cases. func TestEndToEnd(t *testing.T) { - dir, err := ioutil.TempDir("", "stringer") + dir, err := os.MkdirTemp("", "stringer") if err != nil { t.Fatal(err) } diff --git a/enumer.go b/enumer.go index 5b71f7a..44c4e92 100644 --- a/enumer.go +++ b/enumer.go @@ -1,10 +1,8 @@ package main -import "fmt" - -const customInvalidError = `// ErrInvalid%[1]s is a custom error type for %[1]s -var ErrInvalid%[1]s = errors.New("invalid value for %[1]s") -` +import ( + "fmt" +) // Arguments to format are: // @@ -99,15 +97,11 @@ func (g *Generator) buildBasicExtras(runs [][]Value, typeName string, runsThresh g.printNamesSlice(runs, typeName, runsThreshold) // Print the basic extra methods + stringNameToValueErr := fmt.Sprintf(`fmt.Errorf("%%s does not belong to %s values", s)`, typeName) if customError { - g.Printf(customInvalidError, typeName) - } - - stringNameToValueErr := `fmt.Errorf("%%s does not belong to %s values", s)` - if customError { - stringNameToValueErr = `ErrInvalid%s` + stringNameToValueErr = `enumer.InvalidEnumValueError` } - g.Printf(stringNameToValueMethod, typeName, fmt.Sprintf(stringNameToValueErr, typeName)) + g.Printf(stringNameToValueMethod, typeName, stringNameToValueErr) g.Printf(stringValuesMethod, typeName) g.Printf(stringsMethod, typeName) diff --git a/golden_test.go b/golden_test.go index 870204e..ef0b087 100644 --- a/golden_test.go +++ b/golden_test.go @@ -10,7 +10,6 @@ package main import ( - "io/ioutil" "os" "path/filepath" "strings" @@ -367,7 +366,7 @@ func runGoldenTest(t *testing.T, test Golden, file := test.name + ".go" input := "package test\n" + test.input - dir, err := ioutil.TempDir("", "stringer") + dir, err := os.MkdirTemp("", "stringer") if err != nil { t.Error(err) } @@ -379,7 +378,7 @@ func runGoldenTest(t *testing.T, test Golden, }() absFile := filepath.Join(dir, file) - err = ioutil.WriteFile(absFile, []byte(input), 0644) + err = os.WriteFile(absFile, []byte(input), 0644) if err != nil { t.Error(err) } @@ -394,7 +393,7 @@ func runGoldenTest(t *testing.T, test Golden, if got != loadGolden(test.name) { // Use this to help build a golden text when changes are needed //goldenFile := fmt.Sprintf("./testdata/%v.golden", test.name) - //err = ioutil.WriteFile(goldenFile, []byte(got), 0644) + //err = os.WriteFile(goldenFile, []byte(got), 0644) //if err != nil { // t.Error(err) //} @@ -403,12 +402,7 @@ func runGoldenTest(t *testing.T, test Golden, } func loadGolden(name string) string { - fh, err := os.Open("testdata/" + name + ".golden") - if err != nil { - return "" - } - defer fh.Close() - b, err := ioutil.ReadAll(fh) + b, err := os.ReadFile("testdata/" + name + ".golden") if err != nil { return "" } diff --git a/pkg/enumer/enumer.go b/pkg/enumer/enumer.go new file mode 100644 index 0000000..e24a14a --- /dev/null +++ b/pkg/enumer/enumer.go @@ -0,0 +1,5 @@ +package enumer + +import "errors" + +var InvalidEnumValueError = errors.New("invalid enum value") diff --git a/stringer.go b/stringer.go index 57cf793..6ed24dd 100644 --- a/stringer.go +++ b/stringer.go @@ -18,7 +18,6 @@ import ( "go/importer" "go/token" "go/types" - "io/ioutil" "log" "os" "path/filepath" @@ -56,7 +55,7 @@ var ( trimPrefix = flag.String("trimprefix", "", "transform each item name by removing a prefix or comma separated list of prefixes. Default: \"\"") addPrefix = flag.String("addprefix", "", "transform each item name by adding a prefix. Default: \"\"") linecomment = flag.Bool("linecomment", false, "use line comment text as printed text when present") - customError = flag.Bool("customerror", false, "if true, add a custom error type. Default: false") + customError = flag.Bool("customerror", false, "if true, a custom error will be returned by the `String` function. Default: false") ) var comments arrayFlags @@ -133,8 +132,9 @@ func main() { g.Printf("\t\"strconv\"\n") } if *customError { - g.Printf("\t\"errors\"\n") + g.Printf("\n\t\"github.com/dmarkham/enumer/pkg/enumer\"\n") } + g.Printf(")\n") // Run generate for each type. @@ -153,12 +153,11 @@ func main() { } // Write to tmpfile first - tmpFile, err := ioutil.TempFile(dir, fmt.Sprintf("%s_enumer_", typs[0])) + tmpFile, err := os.CreateTemp(dir, fmt.Sprintf("%s_enumer_", typs[0])) if err != nil { log.Fatalf("creating temporary file for output: %s", err) } - _, err = tmpFile.Write(src) - if err != nil { + if _, err = tmpFile.Write(src); err != nil { tmpFile.Close() os.Remove(tmpFile.Name()) log.Fatalf("writing output: %s", err) @@ -166,8 +165,7 @@ func main() { tmpFile.Close() // Rename tmpfile to output file - err = os.Rename(tmpFile.Name(), outputName) - if err != nil { + if err := os.Rename(tmpFile.Name(), outputName); err != nil { log.Fatalf("moving tempfile to output file: %s", err) } } diff --git a/testdata/dayWithCustomError.golden b/testdata/dayWithCustomError.golden index 5b964d3..2438ce4 100644 --- a/testdata/dayWithCustomError.golden +++ b/testdata/dayWithCustomError.golden @@ -54,9 +54,6 @@ var _DayNames = []string{ _DayName[62:71], } -// ErrInvalidDay is a custom error type for Day -var ErrInvalidDay = errors.New("invalid value for Day") - // DayString retrieves an enum value from the enum constants string name. // Throws an error if the param is not part of the enum. func DayString(s string) (Day, error) { @@ -67,7 +64,7 @@ func DayString(s string) (Day, error) { if val, ok := _DayNameToValueMap[strings.ToLower(s)]; ok { return val, nil } - return 0, ErrInvalidDay + return 0, enumer.InvalidEnumValueError } // DayValues returns all values of the enum