publicweb.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. package guardiand
  import (
	"context"
	"fmt"
	"net"
	"net/http"
	"strings"
	"time"

	"github.com/certusone/wormhole/node/pkg/proto/publicrpc/v1"
	"github.com/certusone/wormhole/node/pkg/supervisor"
	"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
	"github.com/improbable-eng/grpc-web/go/grpcweb"
	"go.uber.org/zap"
	"golang.org/x/crypto/acme"
	"golang.org/x/crypto/acme/autocert"
	"google.golang.org/grpc"
)
  17. func allowCORSWrapper(h http.Handler) http.Handler {
  18. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  19. if origin := r.Header.Get("Origin"); origin != "" {
  20. w.Header().Set("Access-Control-Allow-Origin", origin)
  21. if r.Method == "OPTIONS" && r.Header.Get("Access-Control-Request-Method") != "" {
  22. corsPreflightHandler(w, r)
  23. return
  24. }
  25. }
  26. h.ServeHTTP(w, r)
  27. })
  28. }
  // corsPreflightHandler answers a CORS preflight by advertising which request
// headers and HTTP methods cross-origin callers may use. It only sets response
// headers. It does not inspect the request and writes no status or body itself.
func corsPreflightHandler(w http.ResponseWriter, r *http.Request) {
	allowedHeaders := strings.Join([]string{
		"content-type", "accept", "x-user-agent", "x-grpc-web",
		"grpc-status", "grpc-message", "authorization",
	}, ",")
	allowedMethods := strings.Join([]string{
		"GET", "HEAD", "POST", "PUT", "DELETE",
	}, ",")

	hdr := w.Header()
	hdr.Set("Access-Control-Allow-Headers", allowedHeaders)
	hdr.Set("Access-Control-Allow-Methods", allowedMethods)
}
  43. func publicwebServiceRunnable(
  44. logger *zap.Logger,
  45. listenAddr string,
  46. upstreamAddr string,
  47. grpcServer *grpc.Server,
  48. tlsHostname string,
  49. tlsProd bool,
  50. tlsCacheDir string,
  51. ) (supervisor.Runnable, error) {
  52. return func(ctx context.Context) error {
  53. conn, err := grpc.DialContext(
  54. ctx,
  55. fmt.Sprintf("unix:///%s", upstreamAddr),
  56. grpc.WithBlock(),
  57. grpc.WithInsecure())
  58. if err != nil {
  59. return fmt.Errorf("failed to dial upstream: %s", err)
  60. }
  61. gwmux := runtime.NewServeMux()
  62. err = publicrpcv1.RegisterPublicRPCServiceHandler(ctx, gwmux, conn)
  63. if err != nil {
  64. panic(err)
  65. }
  66. mux := http.NewServeMux()
  67. grpcWebServer := grpcweb.WrapServer(grpcServer)
  68. mux.Handle("/", allowCORSWrapper(http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
  69. if grpcWebServer.IsGrpcWebRequest(req) {
  70. grpcWebServer.ServeHTTP(resp, req)
  71. } else {
  72. gwmux.ServeHTTP(resp, req)
  73. }
  74. })))
  75. srv := &http.Server{
  76. Handler: mux,
  77. }
  78. // TLS setup
  79. if tlsHostname != "" {
  80. logger.Info("provisioning Let's Encrypt certificate", zap.String("hostname", tlsHostname))
  81. var acmeApi string
  82. if tlsProd {
  83. logger.Info("using production Let's Encrypt server")
  84. acmeApi = autocert.DefaultACMEDirectory
  85. } else {
  86. logger.Info("using staging Let's Encrypt server")
  87. acmeApi = "https://acme-staging-v02.api.letsencrypt.org/directory"
  88. }
  89. certManager := autocert.Manager{
  90. Prompt: autocert.AcceptTOS,
  91. HostPolicy: autocert.HostWhitelist(tlsHostname),
  92. Cache: autocert.DirCache(tlsCacheDir),
  93. Client: &acme.Client{DirectoryURL: acmeApi},
  94. }
  95. srv.TLSConfig = certManager.TLSConfig()
  96. logger.Info("certificate provisioning configured")
  97. }
  98. var listener net.Listener
  99. // If listenAddr is prefixed by "sd:", look for a matching systemd socket.
  100. if strings.HasPrefix(listenAddr, "sd:") {
  101. listeners, err := getSDListeners()
  102. if err != nil {
  103. return fmt.Errorf("failed to get systemd listeners: %w", err)
  104. }
  105. addr := listenAddr[3:]
  106. for _, v := range listeners {
  107. logger.Debug("found systemd socket", zap.String("addr", v.Addr().String()))
  108. if v.Addr().String() == addr {
  109. listener = v
  110. }
  111. }
  112. if listener == nil {
  113. all := make([]string, len(listeners))
  114. for i := range listeners {
  115. all[i] = listeners[i].Addr().String()
  116. }
  117. return fmt.Errorf("no valid systemd listeners, got: %s", strings.Join(all, ","))
  118. }
  119. } else {
  120. listener, err = net.Listen("tcp", listenAddr)
  121. if err != nil {
  122. return fmt.Errorf("failed to listen: %v", err)
  123. }
  124. }
  125. supervisor.Signal(ctx, supervisor.SignalHealthy)
  126. errC := make(chan error)
  127. go func() {
  128. logger.Info("publicweb server listening", zap.String("addr", srv.Addr))
  129. if tlsHostname != "" {
  130. errC <- srv.ServeTLS(listener, "", "")
  131. } else {
  132. errC <- srv.Serve(listener)
  133. }
  134. }()
  135. select {
  136. case <-ctx.Done():
  137. // non-graceful shutdown
  138. if err := srv.Close(); err != nil {
  139. return err
  140. }
  141. return ctx.Err()
  142. case err := <-errC:
  143. return err
  144. }
  145. }, nil
  146. }