/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/handler.go

https://github.com/dotcloud/docker · Go · 212 lines · 177 code · 25 blank · 10 comment · 50 complexity · d32e9ee1710fa283bf87d4f63611c9cb MD5 · raw file

  1. package runtime
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "net/textproto"
  9. "github.com/golang/protobuf/proto"
  10. "github.com/grpc-ecosystem/grpc-gateway/internal"
  11. "google.golang.org/grpc/grpclog"
  12. )
  // errEmptyResponse is handed to the stream error handler when a stream's
// recv callback yields a nil message together with a nil error.
var (
	errEmptyResponse = errors.New("empty response")
)
  14. // ForwardResponseStream forwards the stream from gRPC server to REST client.
  15. func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
  16. f, ok := w.(http.Flusher)
  17. if !ok {
  18. grpclog.Infof("Flush not supported in %T", w)
  19. http.Error(w, "unexpected type of web server", http.StatusInternalServerError)
  20. return
  21. }
  22. md, ok := ServerMetadataFromContext(ctx)
  23. if !ok {
  24. grpclog.Infof("Failed to extract ServerMetadata from context")
  25. http.Error(w, "unexpected error", http.StatusInternalServerError)
  26. return
  27. }
  28. handleForwardResponseServerMetadata(w, mux, md)
  29. w.Header().Set("Transfer-Encoding", "chunked")
  30. w.Header().Set("Content-Type", marshaler.ContentType())
  31. if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil {
  32. HTTPError(ctx, mux, marshaler, w, req, err)
  33. return
  34. }
  35. var delimiter []byte
  36. if d, ok := marshaler.(Delimited); ok {
  37. delimiter = d.Delimiter()
  38. } else {
  39. delimiter = []byte("\n")
  40. }
  41. var wroteHeader bool
  42. for {
  43. resp, err := recv()
  44. if err == io.EOF {
  45. return
  46. }
  47. if err != nil {
  48. handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
  49. return
  50. }
  51. if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
  52. handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
  53. return
  54. }
  55. var buf []byte
  56. switch {
  57. case resp == nil:
  58. buf, err = marshaler.Marshal(errorChunk(streamError(ctx, mux.streamErrorHandler, errEmptyResponse)))
  59. default:
  60. result := map[string]interface{}{"result": resp}
  61. if rb, ok := resp.(responseBody); ok {
  62. result["result"] = rb.XXX_ResponseBody()
  63. }
  64. buf, err = marshaler.Marshal(result)
  65. }
  66. if err != nil {
  67. grpclog.Infof("Failed to marshal response chunk: %v", err)
  68. handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
  69. return
  70. }
  71. if _, err = w.Write(buf); err != nil {
  72. grpclog.Infof("Failed to send response chunk: %v", err)
  73. return
  74. }
  75. wroteHeader = true
  76. if _, err = w.Write(delimiter); err != nil {
  77. grpclog.Infof("Failed to send delimiter chunk: %v", err)
  78. return
  79. }
  80. f.Flush()
  81. }
  82. }
  83. func handleForwardResponseServerMetadata(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
  84. for k, vs := range md.HeaderMD {
  85. if h, ok := mux.outgoingHeaderMatcher(k); ok {
  86. for _, v := range vs {
  87. w.Header().Add(h, v)
  88. }
  89. }
  90. }
  91. }
  92. func handleForwardResponseTrailerHeader(w http.ResponseWriter, md ServerMetadata) {
  93. for k := range md.TrailerMD {
  94. tKey := textproto.CanonicalMIMEHeaderKey(fmt.Sprintf("%s%s", MetadataTrailerPrefix, k))
  95. w.Header().Add("Trailer", tKey)
  96. }
  97. }
  98. func handleForwardResponseTrailer(w http.ResponseWriter, md ServerMetadata) {
  99. for k, vs := range md.TrailerMD {
  100. tKey := fmt.Sprintf("%s%s", MetadataTrailerPrefix, k)
  101. for _, v := range vs {
  102. w.Header().Add(tKey, v)
  103. }
  104. }
  105. }
  // responseBody is implemented by generated response messages whose
// google.api.HttpRule sets `response_body`. XXX_ResponseBody returns the value
// that should be marshaled as the HTTP response body in place of the whole
// message.
type responseBody interface{ XXX_ResponseBody() interface{} }
  111. // ForwardResponseMessage forwards the message "resp" from gRPC server to REST client.
  112. func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
  113. md, ok := ServerMetadataFromContext(ctx)
  114. if !ok {
  115. grpclog.Infof("Failed to extract ServerMetadata from context")
  116. }
  117. handleForwardResponseServerMetadata(w, mux, md)
  118. handleForwardResponseTrailerHeader(w, md)
  119. contentType := marshaler.ContentType()
  120. // Check marshaler on run time in order to keep backwards compatibility
  121. // An interface param needs to be added to the ContentType() function on
  122. // the Marshal interface to be able to remove this check
  123. if typeMarshaler, ok := marshaler.(contentTypeMarshaler); ok {
  124. contentType = typeMarshaler.ContentTypeFromMessage(resp)
  125. }
  126. w.Header().Set("Content-Type", contentType)
  127. if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
  128. HTTPError(ctx, mux, marshaler, w, req, err)
  129. return
  130. }
  131. var buf []byte
  132. var err error
  133. if rb, ok := resp.(responseBody); ok {
  134. buf, err = marshaler.Marshal(rb.XXX_ResponseBody())
  135. } else {
  136. buf, err = marshaler.Marshal(resp)
  137. }
  138. if err != nil {
  139. grpclog.Infof("Marshal error: %v", err)
  140. HTTPError(ctx, mux, marshaler, w, req, err)
  141. return
  142. }
  143. if _, err = w.Write(buf); err != nil {
  144. grpclog.Infof("Failed to write response: %v", err)
  145. }
  146. handleForwardResponseTrailer(w, md)
  147. }
  148. func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error {
  149. if len(opts) == 0 {
  150. return nil
  151. }
  152. for _, opt := range opts {
  153. if err := opt(ctx, w, resp); err != nil {
  154. grpclog.Infof("Error handling ForwardResponseOptions: %v", err)
  155. return err
  156. }
  157. }
  158. return nil
  159. }
  160. func handleForwardResponseStreamError(ctx context.Context, wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, req *http.Request, mux *ServeMux, err error) {
  161. serr := streamError(ctx, mux.streamErrorHandler, err)
  162. if !wroteHeader {
  163. w.WriteHeader(int(serr.HttpCode))
  164. }
  165. buf, merr := marshaler.Marshal(errorChunk(serr))
  166. if merr != nil {
  167. grpclog.Infof("Failed to marshal an error: %v", merr)
  168. return
  169. }
  170. if _, werr := w.Write(buf); werr != nil {
  171. grpclog.Infof("Failed to notify error to client: %v", werr)
  172. return
  173. }
  174. }
  175. // streamError returns the payload for the final message in a response stream
  176. // that represents the given err.
  177. func streamError(ctx context.Context, errHandler StreamErrorHandlerFunc, err error) *StreamError {
  178. serr := errHandler(ctx, err)
  179. if serr != nil {
  180. return serr
  181. }
  182. // TODO: log about misbehaving stream error handler?
  183. return DefaultHTTPStreamErrorHandler(ctx, err)
  184. }
  185. func errorChunk(err *StreamError) map[string]proto.Message {
  186. return map[string]proto.Message{"error": (*internal.StreamError)(err)}
  187. }