Bindless Kubernetes port-forwarding in Golang

February 9, 2023

Update: We’ve released this code as a library. https://github.com/microcumulus/k8s-portforward-conn gives an easy way to port forward in Go without actually binding to a port on the local machine.

For the services we’re building at microcumulus, we have a need to get access to remote services that are not (yet) exposed outside the cluster. It’s a fairly straightforward thing to do in bash: just kubectl port-forward pod 8080:8080 for example, and you can then hit localhost:8080. When writing code that we expect to be highly concurrent, though, we don’t want to bother with the potential port overlaps and overhead of running a local bind and a second net.Dial.

We know kubectl is doing this in Go, and ideally we’d like to do the same, but just get back a net.Conn that directly tunnels to the remote pod. Thus begins a dive into the Kubernetes codebase, some of which has changed in the last few years to make this code path more obvious.

If you’re curious what the end result looks like, check out the code below. We built this out of a) a desire to use appropriately licensed open source code, and b) a desire to not need a step that binds to local ports just to return a usable net.Conn. Ultimately the steps we discovered are as follows:

  1. Keep a reference to the rest.Config around; it’s not easy to get from a kubernetes.Clientset, but the opposite (clientset from config) is trivial.
  2. Use k8s.io/client-go/transport/spdy to get an http.RoundTripper and a spdy.Upgrader from the rest.Config (spdy.RoundTripperFor), and build a dialer from them with spdy.NewDialer.
  3. Use a RESTClient to build the URL to the appropriate portforward subresource, or just Sprintf("/api/v1/namespaces/%s/pods/%s/portforward", pod.Namespace, pod.Name). Then use your fresh Dialer to dial that URL and get back an httpstream.Connection.
  4. Build some headers for two httpstream.Stream substreams: an error substream and a data substream, and use your Connection to get those streams.
  5. Implement net.Conn in such a way that it passes data through to the data substream, and checks errors on both the error stream and the data stream.

This approach is not perfect, but it’s definitely the best one we’ve found so far. There is a little left to be desired on the error handling side, but ultimately this code works quite well in a retry/backoff context (highly recommend cenkalti/backoff), and will likely evolve, perhaps into its own library.

Code:

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"time"

	corev1 "k8s.io/api/core/v1"
	v1 "k8s.io/api/core/v1"
	"k8s.io/apimachinery/pkg/util/httpstream"
	"k8s.io/client-go/kubernetes"
	"k8s.io/client-go/rest"
	"k8s.io/client-go/tools/portforward"
	"k8s.io/client-go/transport/spdy"
)

// portForward opens a port-forward session to one port of pod and returns it
// as a net.Conn, without binding a port on the local machine.
//
// It builds the URL of the pod's "portforward" subresource, dials it over SPDY
// using the transport and upgrader derived from rc, and then creates two
// streams on the resulting connection: an error stream and a data stream, both
// tagged with port. The returned conn reads from and writes to the data
// stream; a background goroutine (watchErr) reads the error stream.
//
// ctx is only handed to that background goroutine; it does not bound the dial
// or the stream creation performed here.
//
// If a stream cannot be created after the connection has been established, the
// connection is closed before returning so it is not leaked.
func portForward(ctx context.Context, rc *rest.Config, pod corev1.Pod, port string) (net.Conn, error) {
	cs, err := kubernetes.NewForConfig(rc)
	if err != nil {
		return nil, fmt.Errorf("error creating http client: %w", err)
	}

	req := cs.RESTClient().
		Post().
		Prefix("api/v1").
		Resource("pods").
		Name(pod.Name).
		Namespace(pod.Namespace).
		SubResource("portforward")

	transport, upgrader, err := spdy.RoundTripperFor(rc)
	if err != nil {
		return nil, fmt.Errorf("error getting transport/upgrader from restconfig: %w", err)
	}

	dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", req.URL())
	conn, _, err := dialer.Dial(portforward.PortForwardProtocolV1Name)
	if err != nil {
		return nil, fmt.Errorf("error dialing for conn: %w", err)
	}

	// Both streams carry the same port and request ID; only the stream type
	// header differs between them.
	headers := http.Header{}
	headers.Set(v1.StreamType, v1.StreamTypeError)
	headers.Set(v1.PortHeader, port)
	headers.Set(v1.PortForwardRequestIDHeader, "1")

	errorStream, err := conn.CreateStream(headers)
	if err != nil {
		// The stream error is what the caller needs; a failure to close the
		// half-initialised connection adds nothing actionable.
		_ = conn.Close()
		return nil, fmt.Errorf("error creating err stream: %w", err)
	}
	// We never write to the error stream, only read from it, so close our
	// side right away. The result is ignored on purpose: the stream is still
	// read by watchErr, which reports any failure it sees.
	_ = errorStream.Close()

	headers.Set(v1.StreamType, v1.StreamTypeData)
	dataStream, err := conn.CreateStream(headers)
	if err != nil {
		_ = conn.Close()
		return nil, fmt.Errorf("error creating data stream: %w", err)
	}

	fc := &fakeConn{
		parent: conn,
		port:   port,
		err:    errorStream,
		// Room for the two sends watchErr can make, so that goroutine can
		// finish even when nobody calls Read/Write/Close again and ctx is
		// never cancelled.
		errch: make(chan error, 2),
		data:  dataStream,
		pod:   pod,
	}
	go fc.watchErr(ctx)

	return fc, nil
}



// fakeAddr is the net.Addr handed out when something asks for either end of
// this "network connection". It exists for debugging: the string value is the
// address text itself, and the network is reported as "memory" because the
// conn is in-memory rather than tcp/udp.
type fakeAddr string

// Network always reports "memory".
func (a fakeAddr) Network() string { return "memory" }

// String returns the address text unchanged.
func (a fakeAddr) String() string { return string(a) }



// fakeConn is the guts of our connection: it presents one port-forward
// session as a net.Conn. Most of the surrounding code deals with channels and
// with the fact that two different streams may fail, either of which is a
// problem for our callers.
type fakeConn struct {
	// parent is the connection that both streams were created on.
	parent httpstream.Connection

	// data is the stream Read and Write pass bytes through.
	data httpstream.Stream

	// err is the stream watchErr reads; anything found there is reported as
	// an error.
	err httpstream.Stream

	// errch hands errors found by watchErr to Read, Write and Close.
	errch chan error

	// port and pod are only used to build the addresses returned by
	// LocalAddr and RemoteAddr.
	port string
	pod  v1.Pod
}

// watchErr reads the error stream to its end and forwards whatever it finds
// to errch. io.ReadAll blocks until the stream ends or fails, so this is meant
// to run in its own goroutine.
//
// Two distinct conditions are reported, with distinct messages so callers can
// tell them apart:
//   - reading the error stream itself failed;
//   - the stream delivered a non-empty payload, which is treated as an error
//     message (presumably written by the remote side; verify against the
//     port-forward protocol).
//
// A clean end of stream with no payload reports nothing. Each send is given
// up on once ctx is cancelled, so the goroutine does not wait on a receiver
// past that point.
func (f *fakeConn) watchErr(ctx context.Context) {
	report := func(e error) {
		select {
		case <-ctx.Done():
		case f.errch <- e:
		}
	}

	msg, err := io.ReadAll(f.err)
	if err != nil {
		report(fmt.Errorf("error reading error stream: %w", err))
	}
	if len(msg) > 0 {
		report(fmt.Errorf("error reported on error stream for port %s: %s", f.port, msg))
	}
}

// Read implements net.Conn. If an error is already waiting on errch it is
// returned without touching the data stream; otherwise the read is delegated
// to the data stream. The check is non-blocking, so an error that arrives
// while the data stream read is in progress is not seen by this call.
func (f *fakeConn) Read(b []byte) (n int, err error) {
	select {
	case pending := <-f.errch:
		return 0, pending
	default:
		return f.data.Read(b)
	}
}

// Write implements net.Conn. If an error is already waiting on errch it is
// returned and nothing is written; otherwise the bytes are handed to the data
// stream. The check on errch never blocks.
func (f *fakeConn) Write(b []byte) (n int, err error) {
	select {
	case pending := <-f.errch:
		return 0, pending
	default:
		return f.data.Write(b)
	}
}

// Close implements net.Conn. It tears the session down in order: pick up any
// error already waiting on errch, close the data stream, remove both streams
// from the parent connection, then close the parent connection. Every non-nil
// error met along the way is kept and returned joined together; the result is
// nil when none occurred.
func (f *fakeConn) Close() error {
	var collected []error
	keep := func(e error) {
		if e != nil {
			collected = append(collected, e)
		}
	}

	// Non-blocking: only an error that is already pending is picked up.
	select {
	case pending := <-f.errch:
		keep(pending)
	default:
	}

	keep(f.data.Close())
	f.parent.RemoveStreams(f.data, f.err)
	keep(f.parent.Close())

	return errors.Join(collected...)
}

// LocalAddr implements net.Conn. The local side has no real socket, so the
// address is the text "memory:" followed by the forwarded port.
func (f *fakeConn) LocalAddr() net.Addr {
	local := fakeAddr("memory:" + f.port)
	return local
}

// RemoteAddr implements net.Conn. The address names the target pod and port
// in the form "k8s/<namespace>/<pod name>:<port>".
func (f *fakeConn) RemoteAddr() net.Addr {
	target := fmt.Sprintf("k8s/%s/%s:%s", f.pod.Namespace, f.pod.Name, f.port)
	return fakeAddr(target)
}

// SetDeadline implements net.Conn by approximation: it sets the parent
// connection's idle timeout to the time remaining until t. This is a
// connection-wide idle timeout, not a per-operation deadline, and it is the
// same setting SetReadDeadline and SetWriteDeadline change.
//
// Per the net.Conn contract a zero t means "no deadline". time.Until of the
// zero time is a large negative duration, so that case is mapped to an idle
// timeout of 0 instead.
// NOTE(review): assumes the underlying connection treats 0 as "no idle
// timeout" — confirm against the httpstream implementation in use.
//
// It always returns nil.
func (f *fakeConn) SetDeadline(t time.Time) error {
	if t.IsZero() {
		f.parent.SetIdleTimeout(0)
		return nil
	}
	f.parent.SetIdleTimeout(time.Until(t))
	return nil
}

// SetReadDeadline implements net.Conn by approximation: it sets the parent
// connection's idle timeout to the time remaining until t. The timeout is
// connection-wide, so it is not limited to reads.
//
// Per the net.Conn contract a zero t means "no deadline". time.Until of the
// zero time is a large negative duration, so that case is mapped to an idle
// timeout of 0 instead.
// NOTE(review): assumes the underlying connection treats 0 as "no idle
// timeout" — confirm against the httpstream implementation in use.
//
// It always returns nil.
func (f *fakeConn) SetReadDeadline(t time.Time) error {
	if t.IsZero() {
		f.parent.SetIdleTimeout(0)
		return nil
	}
	f.parent.SetIdleTimeout(time.Until(t))
	return nil
}

// SetWriteDeadline implements net.Conn by approximation: it sets the parent
// connection's idle timeout to the time remaining until t. The timeout is
// connection-wide, so it is not limited to writes.
//
// Per the net.Conn contract a zero t means "no deadline". time.Until of the
// zero time is a large negative duration, so that case is mapped to an idle
// timeout of 0 instead.
// NOTE(review): assumes the underlying connection treats 0 as "no idle
// timeout" — confirm against the httpstream implementation in use.
//
// It always returns nil.
func (f *fakeConn) SetWriteDeadline(t time.Time) error {
	if t.IsZero() {
		f.parent.SetIdleTimeout(0)
		return nil
	}
	f.parent.SetIdleTimeout(time.Until(t))
	return nil
}