| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162 |
- package guardiand
- import (
- "context"
- "fmt"
- "github.com/certusone/wormhole/node/pkg/proto/publicrpc/v1"
- "github.com/certusone/wormhole/node/pkg/supervisor"
- "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
- "github.com/improbable-eng/grpc-web/go/grpcweb"
- "go.uber.org/zap"
- "golang.org/x/crypto/acme"
- "golang.org/x/crypto/acme/autocert"
- "google.golang.org/grpc"
- "net"
- "net/http"
- "strings"
- )
- func allowCORSWrapper(h http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if origin := r.Header.Get("Origin"); origin != "" {
- w.Header().Set("Access-Control-Allow-Origin", origin)
- if r.Method == "OPTIONS" && r.Header.Get("Access-Control-Request-Method") != "" {
- corsPreflightHandler(w, r)
- return
- }
- }
- h.ServeHTTP(w, r)
- })
- }
// corsPreflightHandler answers a CORS preflight request. It advertises the
// request headers and HTTP methods that cross-origin callers may use. It only
// sets response headers and ignores the request. It writes no body or status,
// so net/http sends an implicit 200 OK unless the caller writes something
// else.
func corsPreflightHandler(w http.ResponseWriter, r *http.Request) {
	allowedHeaders := strings.Join([]string{
		"content-type",
		"accept",
		"x-user-agent",
		"x-grpc-web",
		"grpc-status",
		"grpc-message",
		"authorization",
	}, ",")
	allowedMethods := strings.Join([]string{"GET", "HEAD", "POST", "PUT", "DELETE"}, ",")

	hdr := w.Header()
	hdr.Set("Access-Control-Allow-Headers", allowedHeaders)
	hdr.Set("Access-Control-Allow-Methods", allowedMethods)
}
- func publicwebServiceRunnable(
- logger *zap.Logger,
- listenAddr string,
- upstreamAddr string,
- grpcServer *grpc.Server,
- tlsHostname string,
- tlsProd bool,
- tlsCacheDir string,
- ) (supervisor.Runnable, error) {
- return func(ctx context.Context) error {
- conn, err := grpc.DialContext(
- ctx,
- fmt.Sprintf("unix:///%s", upstreamAddr),
- grpc.WithBlock(),
- grpc.WithInsecure())
- if err != nil {
- return fmt.Errorf("failed to dial upstream: %s", err)
- }
- gwmux := runtime.NewServeMux()
- err = publicrpcv1.RegisterPublicRPCServiceHandler(ctx, gwmux, conn)
- if err != nil {
- panic(err)
- }
- mux := http.NewServeMux()
- grpcWebServer := grpcweb.WrapServer(grpcServer)
- mux.Handle("/", allowCORSWrapper(http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
- if grpcWebServer.IsGrpcWebRequest(req) {
- grpcWebServer.ServeHTTP(resp, req)
- } else {
- gwmux.ServeHTTP(resp, req)
- }
- })))
- srv := &http.Server{
- Handler: mux,
- }
- // TLS setup
- if tlsHostname != "" {
- logger.Info("provisioning Let's Encrypt certificate", zap.String("hostname", tlsHostname))
- var acmeApi string
- if tlsProd {
- logger.Info("using production Let's Encrypt server")
- acmeApi = autocert.DefaultACMEDirectory
- } else {
- logger.Info("using staging Let's Encrypt server")
- acmeApi = "https://acme-staging-v02.api.letsencrypt.org/directory"
- }
- certManager := autocert.Manager{
- Prompt: autocert.AcceptTOS,
- HostPolicy: autocert.HostWhitelist(tlsHostname),
- Cache: autocert.DirCache(tlsCacheDir),
- Client: &acme.Client{DirectoryURL: acmeApi},
- }
- srv.TLSConfig = certManager.TLSConfig()
- logger.Info("certificate provisioning configured")
- }
- var listener net.Listener
- // If listenAddr is prefixed by "sd:", look for a matching systemd socket.
- if strings.HasPrefix(listenAddr, "sd:") {
- listeners, err := getSDListeners()
- if err != nil {
- return fmt.Errorf("failed to get systemd listeners: %w", err)
- }
- addr := listenAddr[3:]
- for _, v := range listeners {
- logger.Debug("found systemd socket", zap.String("addr", v.Addr().String()))
- if v.Addr().String() == addr {
- listener = v
- }
- }
- if listener == nil {
- all := make([]string, len(listeners))
- for i := range listeners {
- all[i] = listeners[i].Addr().String()
- }
- return fmt.Errorf("no valid systemd listeners, got: %s", strings.Join(all, ","))
- }
- } else {
- listener, err = net.Listen("tcp", listenAddr)
- if err != nil {
- return fmt.Errorf("failed to listen: %v", err)
- }
- }
- supervisor.Signal(ctx, supervisor.SignalHealthy)
- errC := make(chan error)
- go func() {
- logger.Info("publicweb server listening", zap.String("addr", srv.Addr))
- if tlsHostname != "" {
- errC <- srv.ServeTLS(listener, "", "")
- } else {
- errC <- srv.Serve(listener)
- }
- }()
- select {
- case <-ctx.Done():
- // non-graceful shutdown
- if err := srv.Close(); err != nil {
- return err
- }
- return ctx.Err()
- case err := <-errC:
- return err
- }
- }, nil
- }
|