/violetear.go

https://github.com/nbari/violetear · Go · 312 lines · 195 code · 33 blank · 84 comment · 68 complexity · 735e971cbc6071db8b968f92fe32ba14 MD5 · raw file

  1. // Package violetear - HTTP router
  2. //
  3. // Basic example:
  4. //
  5. // package main
  6. //
// import (
// "fmt"
// "github.com/nbari/violetear"
// "log"
// "net/http"
// "time"
// )
  13. //
// func catchAll(w http.ResponseWriter, r *http.Request) {
// fmt.Fprint(w, r.URL.Path[1:])
// }
//
// func helloWorld(w http.ResponseWriter, r *http.Request) {
// fmt.Fprint(w, r.URL.Path[1:])
// }
//
// func handleUUID(w http.ResponseWriter, r *http.Request) {
// fmt.Fprint(w, r.URL.Path[1:])
// }
  25. //
  26. // func main() {
  27. // router := violetear.New()
  28. // router.LogRequests = true
  29. // router.RequestID = "REQUEST_LOG_ID"
  30. //
  31. // router.AddRegex(":uuid", `[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}`)
  32. //
  33. // router.HandleFunc("*", catchAll)
  34. // router.HandleFunc("/hello", helloWorld, "GET,HEAD")
  35. // router.HandleFunc("/root/:uuid/item", handleUUID, "POST,PUT")
  36. //
  37. // srv := &http.Server{
  38. // Addr: ":8080",
  39. // Handler: router,
  40. // ReadTimeout: 5 * time.Second,
  41. // WriteTimeout: 7 * time.Second,
  42. // MaxHeaderBytes: 1 << 20,
  43. // }
  44. // log.Fatal(srv.ListenAndServe())
  45. // }
  46. //
  47. package violetear
  48. import (
  49. "context"
  50. "fmt"
  51. "log"
  52. "net/http"
  53. "strings"
  54. )
// Context key and header prefix used by the router.
const (
	// ParamsKey is the context key under which ServeHTTP stores the matched
	// route parameters on the request context before calling the handler.
	ParamsKey key = 0
	// versionHeader is the media-type prefix ServeHTTP searches for in the
	// "Accept" header; the text following its last occurrence is used as the
	// requested route version.
	versionHeader = "application/vnd."
)
// key is the type of the context keys defined by this package (see
// ParamsKey). It is unexported to prevent collisions with context keys
// defined in other packages.
type key int
// Router is an http.Handler that matches the request path against the
// registered routes and dispatches to the matching handler.
type Router struct {
	// dynamicRoutes maps a ":name" path segment to the regular expression
	// registered for it with AddRegex.
	dynamicRoutes dynamicSet
	// routes is the root of the trie holding the registered routes.
	routes *Trie
	// Logger is called after the handler returns when LogRequests is true.
	// New sets it to the package's default logger.
	Logger func(*ResponseWriter, *http.Request)
	// LogRequests, when true, makes ServeHTTP wrap the http.ResponseWriter
	// in a *ResponseWriter and call Logger once the handler has returned.
	LogRequests bool
	// NotFoundHandler configurable http.Handler which is called when no matching
	// route is found. If it is not set, http.NotFoundHandler() is used.
	NotFoundHandler http.Handler
	// NotAllowedHandler configurable http.Handler which is called when the
	// request method is not allowed for the matched route. If it is not set,
	// the handler returned by MethodNotAllowed is used.
	NotAllowedHandler http.Handler
	// PanicHandler is called when a panic is recovered while serving a
	// request. If it is not set, a 500 Internal Server Error is written.
	PanicHandler http.HandlerFunc
	// RequestID is the name of the request-ID header. When it is set and the
	// incoming request carries that header, ServeHTTP copies the value to the
	// response header.
	// NOTE(review): creating an ID when the header is absent is not done in
	// ServeHTTP; confirm whether the ResponseWriter wrapper handles it.
	RequestID string
	// Verbose, when true, makes Handle log every route it adds. New sets it
	// to true.
	Verbose bool
	// err holds the error resulting from building a route; it is returned by
	// GetError.
	err error
}
  87. // New returns a new initialized router.
  88. func New() *Router {
  89. return &Router{
  90. dynamicRoutes: dynamicSet{},
  91. routes: &Trie{},
  92. Logger: logger,
  93. Verbose: true,
  94. }
  95. }
  96. // Handle registers the handler for the given pattern (path, http.Handler, methods).
  97. func (r *Router) Handle(path string, handler http.Handler, httpMethods ...string) *Trie {
  98. var version string
  99. if i := strings.Index(path, "#"); i != -1 {
  100. version = path[i+1:]
  101. path = path[:i]
  102. }
  103. pathParts := r.splitPath(path)
  104. // search for dynamic routes
  105. for _, p := range pathParts {
  106. if strings.HasPrefix(p, ":") {
  107. if _, ok := r.dynamicRoutes[p]; !ok {
  108. r.err = fmt.Errorf("[%s] not found, need to add it using AddRegex(%q, `your regex`", p, p)
  109. return nil
  110. }
  111. }
  112. }
  113. // if no methods, accept ALL
  114. methods := "ALL"
  115. if len(httpMethods) > 0 && len(strings.TrimSpace(httpMethods[0])) > 0 {
  116. methods = httpMethods[0]
  117. }
  118. if r.Verbose {
  119. log.Printf("Adding path: %s [%s] %s", path, methods, version)
  120. }
  121. trie, err := r.routes.Set(pathParts, handler, methods, version)
  122. if err != nil {
  123. r.err = err
  124. return nil
  125. }
  126. return trie
  127. }
// HandleFunc adds a route to the router (path, http.HandlerFunc, methods).
// It forwards its arguments to Handle and returns Handle's result, so it
// returns nil when the route could not be registered.
func (r *Router) HandleFunc(path string, handler http.HandlerFunc, httpMethods ...string) *Trie {
	return r.Handle(path, handler, httpMethods...)
}
// AddRegex adds a ":named" regular expression to the dynamicRoutes and
// returns the error reported by the dynamic set, if any.
//
// Handle looks dynamic segments up by their full text, including the leading
// ":", so name is expected to be written the same way (for example ":uuid").
func (r *Router) AddRegex(name, regex string) error {
	return r.dynamicRoutes.Set(name, regex)
}
  136. // MethodNotAllowed default handler for 405
  137. func (r *Router) MethodNotAllowed() http.HandlerFunc {
  138. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  139. http.Error(w,
  140. http.StatusText(http.StatusMethodNotAllowed),
  141. http.StatusMethodNotAllowed,
  142. )
  143. })
  144. }
  145. // checkMethod check if request method is allowed or not
  146. func (r *Router) checkMethod(node *Trie, method string) http.Handler {
  147. for _, h := range node.Handler {
  148. if h.Method == "ALL" {
  149. return h.Handler
  150. }
  151. if h.Method == method {
  152. return h.Handler
  153. }
  154. }
  155. if r.NotAllowedHandler != nil {
  156. return r.NotAllowedHandler
  157. }
  158. return r.MethodNotAllowed()
  159. }
  160. // dispatch request
  161. func (r *Router) dispatch(node *Trie, key, path, method, version string, leaf bool, params Params) (http.Handler, Params) {
  162. catchall := false
  163. if node.name != "" {
  164. if params == nil {
  165. params = Params{}
  166. }
  167. params.Add("rname", node.name)
  168. }
  169. if len(node.Handler) > 0 && leaf {
  170. return r.checkMethod(node, method), params
  171. } else if node.HasRegex {
  172. for _, n := range node.Node {
  173. if strings.HasPrefix(n.path, ":") {
  174. rx := r.dynamicRoutes[n.path]
  175. if rx.MatchString(key) {
  176. // add param to context
  177. if params == nil {
  178. params = Params{}
  179. }
  180. params.Add(n.path, key)
  181. node, key, path, leaf := node.Get(n.path+path, version)
  182. return r.dispatch(node, key, path, method, version, leaf, params)
  183. }
  184. }
  185. }
  186. if node.HasCatchall {
  187. catchall = true
  188. }
  189. } else if node.HasCatchall {
  190. catchall = true
  191. }
  192. if catchall {
  193. for _, n := range node.Node {
  194. if n.path == "*" {
  195. // add "*" to context
  196. if params == nil {
  197. params = Params{}
  198. }
  199. params.Add("*", key)
  200. if n.name != "" {
  201. params.Add("rname", n.name)
  202. }
  203. return r.checkMethod(n, method), params
  204. }
  205. }
  206. }
  207. // NotFound
  208. if r.NotFoundHandler != nil {
  209. return r.NotFoundHandler, params
  210. }
  211. return http.NotFoundHandler(), params
  212. }
  213. // ServeHTTP dispatches the handler registered in the matched path
  214. func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
  215. // panic handler
  216. defer func() {
  217. if err := recover(); err != nil {
  218. log.Printf("panic: %s", err)
  219. if r.PanicHandler != nil {
  220. r.PanicHandler(w, req)
  221. } else {
  222. http.Error(w, http.StatusText(500), http.StatusInternalServerError)
  223. }
  224. }
  225. }()
  226. // Request-ID
  227. var rid string
  228. if r.RequestID != "" {
  229. if rid = req.Header.Get(r.RequestID); rid != "" {
  230. w.Header().Set(r.RequestID, rid)
  231. }
  232. }
  233. // wrap ResponseWriter
  234. var ww *ResponseWriter
  235. if r.LogRequests {
  236. ww = NewResponseWriter(w, rid)
  237. }
  238. // set version based on the value of "Accept: application/vnd.*"
  239. version := req.Header.Get("Accept")
  240. if i := strings.LastIndex(version, versionHeader); i != -1 {
  241. version = version[len(versionHeader)+i:]
  242. } else {
  243. version = ""
  244. }
  245. // query the path from left to right
  246. node, key, path, leaf := r.routes.Get(req.URL.Path, version)
  247. // dispatch the request
  248. h, p := r.dispatch(node, key, path, req.Method, version, leaf, nil)
  249. // dispatch request
  250. if r.LogRequests {
  251. if p == nil {
  252. h.ServeHTTP(ww, req)
  253. } else {
  254. h.ServeHTTP(ww, req.WithContext(context.WithValue(req.Context(), ParamsKey, p)))
  255. }
  256. r.Logger(ww, req)
  257. } else {
  258. if p == nil {
  259. h.ServeHTTP(w, req)
  260. } else {
  261. h.ServeHTTP(w, req.WithContext(context.WithValue(req.Context(), ParamsKey, p)))
  262. }
  263. }
  264. }
  265. // splitPath returns an slice of the path
  266. func (r *Router) splitPath(p string) []string {
  267. pathParts := strings.FieldsFunc(p, func(c rune) bool {
  268. return c == '/'
  269. })
  270. // root (empty slice)
  271. if len(pathParts) == 0 {
  272. return []string{"/"}
  273. }
  274. return pathParts
  275. }
// GetError returns an error resulted from building a route, if any. It
// reports the error most recently recorded by Handle (and therefore by
// HandleFunc); it is nil when no registration has failed.
func (r *Router) GetError() error {
	return r.err
}