diff --git a/mcp/server.go b/mcp/server.go index 254c2d5e..a84d6e2a 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -239,7 +239,7 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) { func() bool { s.tools.add(st); return true }) } -func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler, error) { +func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out], schemaOpts *jsonschema.ForOptions) (*Tool, ToolHandler, error) { tt := *t // Special handling for an "any" input: treat as an empty object. @@ -248,7 +248,7 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan } var inputResolved *jsonschema.Resolved - if _, err := setSchema[In](&tt.InputSchema, &inputResolved); err != nil { + if _, err := setSchema[In](&tt.InputSchema, &inputResolved, schemaOpts); err != nil { return nil, nil, fmt.Errorf("input schema: %w", err) } @@ -263,7 +263,7 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan ) if t.OutputSchema != nil || reflect.TypeFor[Out]() != reflect.TypeFor[any]() { var err error - elemZero, err = setSchema[Out](&tt.OutputSchema, &outputResolved) + elemZero, err = setSchema[Out](&tt.OutputSchema, &outputResolved, schemaOpts) if err != nil { return nil, nil, fmt.Errorf("output schema: %v", err) } @@ -366,7 +366,7 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan // // TODO(rfindley): we really shouldn't ever return 'null' results. Maybe we // should have a jsonschema.Zero(schema) helper? -func setSchema[T any](sfield *any, rfield **jsonschema.Resolved) (zero any, err error) { +func setSchema[T any](sfield *any, rfield **jsonschema.Resolved, schemaOpts *jsonschema.ForOptions) (zero any, err error) { var internalSchema *jsonschema.Schema if *sfield == nil { rt := reflect.TypeFor[T]() @@ -374,8 +374,7 @@ func setSchema[T any](sfield *any, rfield **jsonschema.Resolved) (zero any, err rt = rt.Elem() zero = reflect.Zero(rt).Interface() } - // TODO: we should be able to pass nil opts here. - internalSchema, err = jsonschema.ForType(rt, &jsonschema.ForOptions{}) + internalSchema, err = jsonschema.ForType(rt, schemaOpts) if err == nil { *sfield = internalSchema } @@ -389,6 +388,20 @@ func setSchema[T any](sfield *any, rfield **jsonschema.Resolved) (zero any, err return zero, err } +// AddToolOption is an option for the AddTool function. +type AddToolOption func(*addToolOptions) + +type addToolOptions struct { + schemaOpts *jsonschema.ForOptions +} + +// WithSchemaOptions returns an AddToolOption that sets options for schema inference. +func WithSchemaOptions(opts *jsonschema.ForOptions) AddToolOption { + return func(ato *addToolOptions) { + ato.schemaOpts = opts + } +} + // AddTool adds a tool and typed tool handler to the server. // // If the tool's input schema is nil, it is set to the schema inferred from the @@ -408,8 +421,14 @@ func setSchema[T any](sfield *any, rfield **jsonschema.Resolved) (zero any, err // Unlike [Server.AddTool], AddTool does a lot automatically, and forces // tools to conform to the MCP spec. See [ToolHandlerFor] for a detailed // description of this automatic behavior. -func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) { - tt, hh, err := toolForErr(t, h) +func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out], opts ...AddToolOption) { + o := addToolOptions{ + schemaOpts: &jsonschema.ForOptions{}, + } + for _, opt := range opts { + opt(&o) + } + tt, hh, err := toolForErr(t, h, o.schemaOpts) if err != nil { panic(fmt.Sprintf("AddTool: tool %q: %v", t.Name, err)) } diff --git a/mcp/server_test.go b/mcp/server_test.go index d8c0df65..41314d86 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -562,7 +562,7 @@ func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out th := func(context.Context, *CallToolRequest, In) (*CallToolResult, Out, error) { return nil, out, nil } - gott, goth, err := toolForErr(tool, th) + gott, goth, err := toolForErr(tool, th, &jsonschema.ForOptions{}) if err != nil { t.Fatal(err) }