Skip to content
59 changes: 45 additions & 14 deletions router_serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (rootRouter *Router) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
closure.Routers = make([]*Router, 1, rootRouter.maxChildrenDepth)
closure.Routers[0] = rootRouter
closure.Contexts = make([]reflect.Value, 1, rootRouter.maxChildrenDepth)
closure.Contexts[0] = reflect.New(rootRouter.contextType)
closure.Contexts[0] = reflect.New(rootRouter.contextType.Type)
closure.currentMiddlewareLen = len(rootRouter.middleware)
closure.RootRouter = rootRouter
closure.Request.rootContext = closure.Contexts[0]
Expand Down Expand Up @@ -220,20 +220,53 @@ func contextsFor(contexts []reflect.Value, routers []*Router) []reflect.Value {

for i := 1; i < routersLen; i++ {
var ctx reflect.Value
if routers[i].contextType == routers[i-1].contextType {
if routers[i].contextType.Type == routers[i-1].contextType.Type {
ctx = contexts[i-1]
} else {
ctx = reflect.New(routers[i].contextType)
ctxType := routers[i].contextType.Type
// set the first field to the parent
f := reflect.Indirect(ctx).Field(0)
f.Set(contexts[i-1])
if routers[i].contextType.IsDerived {
ctx = createDrivedContext(contexts[i-1], ctxType)
} else {
ctxType = reflect.PtrTo(ctxType)
ctx = getMatchedParentContext(contexts[i-1], ctxType)
}
}
contexts = append(contexts, ctx)
}

return contexts
}

func createDrivedContext(context reflect.Value, neededType reflect.Type) reflect.Value {
ctx := reflect.New(neededType)
childCtx := ctx
for {
f := reflect.Indirect(childCtx).Field(0)
if f.Type() != context.Type() && f.Kind() == reflect.Ptr {
childCtx = reflect.New(f.Type().Elem())
f.Set(childCtx)
continue
} else {
f.Set(context)
break
}
}
return ctx
}

func getMatchedParentContext(context reflect.Value, neededType reflect.Type) reflect.Value {
if neededType != context.Type() {
for {
context = reflect.Indirect(context).Field(0)
if context.Type() == neededType {
break
}
}
}
return context
}

// If there's a panic in the root middleware (so that we don't have a route/target), then invoke the root handler or default.
// If there's a panic in other middleware, then invoke the target action's function.
// If there's a panic in the action handler, then invoke the target action's function.
Expand All @@ -250,19 +283,17 @@ func (rootRouter *Router) handlePanic(rw *appResponseWriter, req *Request, err i

for !targetRouter.errorHandler.IsValid() && targetRouter.parent != nil {
targetRouter = targetRouter.parent

// Need to set context to the next context, UNLESS the context is the same type.
curContextStruct := reflect.Indirect(context)
if targetRouter.contextType != curContextStruct.Type() {
context = curContextStruct.Field(0)
if reflect.Indirect(context).Type() != targetRouter.contextType {
panic("bug: shouldn't get here")
}
}
}
}

if targetRouter.errorHandler.IsValid() {
// Need to set context to the next context, UNLESS the context is the same type.
if _, err := validateContext(reflect.Indirect(reflect.New(targetRouter.contextType.Type)).Interface(), reflect.Indirect(context).Type()); err != nil {
panic(err)
}

ctxType := reflect.PtrTo(targetRouter.contextType.Type)
context = getMatchedParentContext(context, ctxType)
invoke(targetRouter.errorHandler, context, []reflect.Value{reflect.ValueOf(rw), reflect.ValueOf(req), reflect.ValueOf(err)})
} else {
http.Error(rw, DefaultPanicResponse, http.StatusInternalServerError)
Expand Down
81 changes: 53 additions & 28 deletions router_setup.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package web

import (
"errors"
"reflect"
"strings"
)
Expand All @@ -19,6 +20,11 @@ const (

var httpMethods = []httpMethod{httpMethodGet, httpMethodPost, httpMethodPut, httpMethodDelete, httpMethodPatch, httpMethodHead, httpMethodOptions}

type ContextSt struct {
Type reflect.Type
IsDerived bool //true if it's drived from main route, false if main route is drived from it
}

// Router implements net/http's Handler interface and is what you attach middleware, routes/handlers, and subrouters to.
type Router struct {
// Hierarchy:
Expand All @@ -27,7 +33,7 @@ type Router struct {
maxChildrenDepth int

// For each request we'll create one of these objects
contextType reflect.Type
contextType ContextSt

// Eg, "/" or "/admin". Any routes added to this router will be prefixed with this.
pathPrefix string
Expand Down Expand Up @@ -89,10 +95,10 @@ var emptyInterfaceType = reflect.TypeOf((*interface{})(nil)).Elem()
// whose purpose is to communicate type information. On each request, an instance of this
// context type will be automatically allocated and sent to handlers.
func New(ctx interface{}) *Router {
validateContext(ctx, nil)
// validateContext(ctx, nil)

r := &Router{}
r.contextType = reflect.TypeOf(ctx)
r.contextType = ContextSt{Type: reflect.TypeOf(ctx)}
r.pathPrefix = "/"
r.maxChildrenDepth = 1
r.root = make(map[httpMethod]*pathNode)
Expand All @@ -116,10 +122,14 @@ func NewWithPrefix(ctx interface{}, pathPrefix string) *Router {
// embed a pointer to the previous context in the first slot. You can also pass
// a pathPrefix that each route will have. If "" is passed, then no path prefix is applied.
func (r *Router) Subrouter(ctx interface{}, pathPrefix string) *Router {
validateContext(ctx, r.contextType)

// Create new router, link up hierarchy
newRouter := &Router{parent: r}
contextType, err := validateContext(ctx, r.contextType.Type)
if err != nil {
panic(err)
}
newRouter.contextType = *contextType
r.children = append(r.children, newRouter)

// Increment maxChildrenDepth if this is the first child of the router
Expand All @@ -131,7 +141,6 @@ func (r *Router) Subrouter(ctx interface{}, pathPrefix string) *Router {
}
}

newRouter.contextType = reflect.TypeOf(ctx)
newRouter.pathPrefix = appendPath(r.pathPrefix, pathPrefix)
newRouter.root = r.root

Expand All @@ -141,7 +150,7 @@ func (r *Router) Subrouter(ctx interface{}, pathPrefix string) *Router {
// Middleware adds the specified middleware tot he router and returns the router.
func (r *Router) Middleware(fn interface{}) *Router {
vfn := reflect.ValueOf(fn)
validateMiddleware(vfn, r.contextType)
validateMiddleware(vfn, r.contextType.Type)
if vfn.Type().NumIn() == 3 {
r.middleware = append(r.middleware, &middlewareHandler{Generic: true, GenericMiddleware: fn.(func(ResponseWriter, *Request, NextMiddlewareFunc))})
} else {
Expand All @@ -154,7 +163,7 @@ func (r *Router) Middleware(fn interface{}) *Router {
// Error sets the specified function as the error handler (when panics happen) and returns the router.
func (r *Router) Error(fn interface{}) *Router {
vfn := reflect.ValueOf(fn)
validateErrorHandler(vfn, r.contextType)
validateErrorHandler(vfn, r.contextType.Type)
r.errorHandler = vfn
return r
}
Expand All @@ -166,7 +175,7 @@ func (r *Router) NotFound(fn interface{}) *Router {
panic("You can only set a NotFoundHandler on the root router.")
}
vfn := reflect.ValueOf(fn)
validateNotFoundHandler(vfn, r.contextType)
validateNotFoundHandler(vfn, r.contextType.Type)
r.notFoundHandler = vfn
return r
}
Expand All @@ -178,7 +187,7 @@ func (r *Router) OptionsHandler(fn interface{}) *Router {
panic("You can only set an OptionsHandler on the root router.")
}
vfn := reflect.ValueOf(fn)
validateOptionsHandler(vfn, r.contextType)
validateOptionsHandler(vfn, r.contextType.Type)
r.optionsHandler = vfn
return r
}
Expand Down Expand Up @@ -220,7 +229,7 @@ func (r *Router) Options(path string, fn interface{}) *Router {

func (r *Router) addRoute(method httpMethod, path string, fn interface{}) *Router {
vfn := reflect.ValueOf(fn)
validateHandler(vfn, r.contextType)
validateHandler(vfn, r.contextType.Type)
fullPath := appendPath(r.pathPrefix, path)
route := &route{Method: method, Path: fullPath, Router: r}
if vfn.Type().NumIn() == 2 {
Expand Down Expand Up @@ -249,26 +258,40 @@ func (r *Router) depth() int {
// Private methods:
//

// Panics unless validation is correct
func validateContext(ctx interface{}, parentCtxType reflect.Type) {
ctxType := reflect.TypeOf(ctx)

if ctxType.Kind() != reflect.Struct {
panic("web: Context needs to be a struct type")
}

if parentCtxType != nil && parentCtxType != ctxType {
if ctxType.NumField() == 0 {
panic("web: Context needs to have first field be a pointer to parent context")
// validate contexts
func validateContext(ctx interface{}, parentCtxType reflect.Type) (*ContextSt, error) {
doCheck := func(ctxType reflect.Type, parentCtxType reflect.Type) error {
for {
if ctxType.Kind() == reflect.Ptr {
ctxType = ctxType.Elem()
}
if ctxType.Kind() != reflect.Struct {
if ctxType == reflect.TypeOf(ctx) {
return errors.New("web: Context needs to be a struct type\n " + ctxType.String())
}
return errors.New("web: Context needs to have first field be a pointer to parent context\n" +
"Main Context: " + parentCtxType.String() + " Given Context: " + reflect.TypeOf(ctx).String())

}
if ctxType == parentCtxType {
break
}
if ctxType.NumField() == 0 {
return errors.New("web: Context needs to have first field be a pointer to parent context")
}
ctxType = ctxType.Field(0).Type
}
return nil
}

fldType := ctxType.Field(0).Type

// Ensure fld is a pointer to parentCtxType
if fldType != reflect.PtrTo(parentCtxType) {
panic("web: Context needs to have first field be a pointer to parent context")
ctxType := reflect.TypeOf(ctx)
if err1 := doCheck(ctxType, parentCtxType); err1 != nil {
if err2 := doCheck(parentCtxType, ctxType); err2 != nil {
return nil, err1
}
return &ContextSt{ctxType, false}, nil
}
return &ContextSt{ctxType, true}, nil
}

// Panics unless fn is a proper handler wrt ctxType
Expand Down Expand Up @@ -338,8 +361,10 @@ func isValidHandler(vfn reflect.Value, ctxType reflect.Type, types ...reflect.Ty
} else if numIn == (typesLen + 1) {
// context, types
firstArgType := fnType.In(0)
if firstArgType != reflect.PtrTo(ctxType) && firstArgType != emptyInterfaceType {
return false
if firstArgType != emptyInterfaceType {
if _, err := validateContext(reflect.Indirect(reflect.New(firstArgType.Elem())).Interface(), ctxType); err != nil {
return false
}
}
typesStartIdx = 1
} else {
Expand Down