Skip to content

Commit

Permalink
Merge pull request #378 from cyberb/main
Browse files Browse the repository at this point in the history
unix socket support
  • Loading branch information
kegsay committed Nov 17, 2023
2 parents b6437ef + bf477c1 commit 62d3798
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 21 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ node_modules
# Go workspaces
go.work
go.work.sum
.idea
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ The Sliding Sync proxy requires some environment variables set to function. They

Here is a short description of each, as of writing:
```
SYNCV3_SERVER Required. The destination homeserver to talk to (CS API HTTPS URL) e.g 'https://matrix-client.matrix.org'
SYNCV3_SERVER Required. The destination homeserver to talk to (CS API HTTPS URL) e.g 'https://matrix-client.matrix.org' (Supports unix socket: /path/to/socket)
SYNCV3_DB Required. The postgres connection string: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING
SYNCV3_SECRET Required. A secret to use to encrypt access tokens. Must remain the same for the lifetime of the database.
SYNCV3_BINDADDR Default: 0.0.0.0:8008. The interface and port to listen on.
SYNCV3_BINDADDR Default: 0.0.0.0:8008. The interface and port to listen on. (Supports unix socket: /path/to/socket)
SYNCV3_TLS_CERT Default: unset. Path to a certificate file to serve to HTTPS clients. Specifying this enables TLS on the bound address.
SYNCV3_TLS_KEY Default: unset. Path to a key file for the certificate. Must be provided along with the certificate file.
SYNCV3_PPROF Default: unset. The bind addr for pprof debugging e.g ':6060'. If not set, does not listen.
Expand Down
4 changes: 2 additions & 2 deletions cmd/syncv3/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ const (

var helpMsg = fmt.Sprintf(`
Environment var
%s Required. The destination homeserver to talk to (CS API HTTPS URL) e.g 'https://matrix-client.matrix.org'
%s Required. The destination homeserver to talk to (CS API HTTPS URL) e.g 'https://matrix-client.matrix.org' (Supports unix socket: /path/to/socket)
%s Required. The postgres connection string: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING
%s Required. A secret to use to encrypt access tokens. Must remain the same for the lifetime of the database.
%s Default: 0.0.0.0:8008. The interface and port to listen on.
%s Default: 0.0.0.0:8008. The interface and port to listen on. (Supports unix socket: /path/to/socket)
%s Default: unset. Path to a certificate file to serve to HTTPS clients. Specifying this enables TLS on the bound address.
%s Default: unset. Path to a key file for the certificate. Must be provided along with the certificate file.
%s Default: unset. The bind addr for pprof debugging e.g ':6060'. If not set, does not listen.
Expand Down
26 changes: 26 additions & 0 deletions internal/util.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
package internal

import (
"context"
"net"
"net/http"
"strings"
)

// Keys returns a slice containing copies of the keys of the given map, in no particular
// order.
func Keys[K comparable, V any](m map[K]V) []K {
Expand All @@ -12,3 +19,22 @@ func Keys[K comparable, V any](m map[K]V) []K {
}
return output
}

func IsUnixSocket(httpOrUnixStr string) bool {
return strings.HasPrefix(httpOrUnixStr, "/")
}

func UnixTransport(httpOrUnixStr string) *http.Transport {
return &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return net.Dial("unix", httpOrUnixStr)
},
}
}

func GetBaseURL(httpOrUnixStr string) string {
if IsUnixSocket(httpOrUnixStr) {
return "http://unix"
}
return httpOrUnixStr
}
28 changes: 28 additions & 0 deletions internal/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,31 @@ func assertSlice(t *testing.T, got, want []string) {
t.Errorf("After sorting, got %v but expected %v", got, want)
}
}

func TestUnixSocket_True(t *testing.T) {
address := "/path/to/socket"
if !IsUnixSocket(address) {
t.Errorf("%s is socket", address)
}
}

func TestUnixSocket_False(t *testing.T) {
address := "localhost:8080"
if IsUnixSocket(address) {
t.Errorf("%s is not socket", address)
}
}

func TestGetBaseUrl_UnixSocket(t *testing.T) {
address := "/path/to/socket"
if GetBaseURL(address) != "http://unix" {
t.Errorf("%s is unix socket", address)
}
}

func TestGetBaseUrl_Http(t *testing.T) {
address := "localhost:8080"
if GetBaseURL(address) != "localhost:8080" {
t.Errorf("%s is not a unix socket", address)
}
}
30 changes: 18 additions & 12 deletions sync2/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"github.com/matrix-org/sliding-sync/internal"
"io"
"net/http"
"net/url"
"time"
Expand Down Expand Up @@ -40,15 +41,20 @@ type HTTPClient struct {

func NewHTTPClient(shortTimeout, longTimeout time.Duration, destHomeServer string) *HTTPClient {
return &HTTPClient{
LongTimeoutClient: &http.Client{
Timeout: longTimeout,
Transport: otelhttp.NewTransport(http.DefaultTransport),
},
Client: &http.Client{
Timeout: shortTimeout,
Transport: otelhttp.NewTransport(http.DefaultTransport),
},
DestinationServer: destHomeServer,
LongTimeoutClient: newClient(longTimeout, destHomeServer),
Client: newClient(shortTimeout, destHomeServer),
DestinationServer: internal.GetBaseURL(destHomeServer),
}
}

func newClient(timeout time.Duration, destHomeServer string) *http.Client {
transport := http.DefaultTransport
if internal.IsUnixSocket(destHomeServer) {
transport = internal.UnixTransport(destHomeServer)
}
return &http.Client{
Timeout: timeout,
Transport: otelhttp.NewTransport(transport),
}
}

Expand All @@ -66,7 +72,7 @@ func (v *HTTPClient) Versions(ctx context.Context) ([]string, error) {
return nil, fmt.Errorf("/versions returned HTTP %d", res.StatusCode)
}
defer res.Body.Close()
body, err := ioutil.ReadAll(res.Body)
body, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -99,7 +105,7 @@ func (v *HTTPClient) WhoAmI(ctx context.Context, accessToken string) (string, st
return "", "", fmt.Errorf("/whoami returned HTTP %d", res.StatusCode)
}
defer res.Body.Close()
body, err := ioutil.ReadAll(res.Body)
body, err := io.ReadAll(res.Body)
if err != nil {
return "", "", err
}
Expand Down
36 changes: 31 additions & 5 deletions v3.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ import (
"context"
"embed"
"encoding/json"
"errors"
"fmt"
"io/fs"
"net"
"net/http"
"os"
"strings"
Expand Down Expand Up @@ -216,12 +219,18 @@ func RunSyncV3Server(h http.Handler, bindAddr, destV2Server, tlsCert, tlsKey str

// Block forever
var err error
if tlsCert != "" && tlsKey != "" {
logger.Info().Msgf("listening TLS on %s", bindAddr)
err = http.ListenAndServeTLS(bindAddr, tlsCert, tlsKey, srv)
if internal.IsUnixSocket(bindAddr) {
logger.Info().Msgf("listening on unix socket %s", bindAddr)
listener := unixSocketListener(bindAddr)
err = http.Serve(listener, srv)
} else {
logger.Info().Msgf("listening on %s", bindAddr)
err = http.ListenAndServe(bindAddr, srv)
if tlsCert != "" && tlsKey != "" {
logger.Info().Msgf("listening TLS on %s", bindAddr)
err = http.ListenAndServeTLS(bindAddr, tlsCert, tlsKey, srv)
} else {
logger.Info().Msgf("listening on %s", bindAddr)
err = http.ListenAndServe(bindAddr, srv)
}
}
if err != nil {
sentry.CaptureException(err)
Expand All @@ -230,6 +239,23 @@ func RunSyncV3Server(h http.Handler, bindAddr, destV2Server, tlsCert, tlsKey str
}
}

func unixSocketListener(bindAddr string) net.Listener {
err := os.Remove(bindAddr)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
logger.Fatal().Err(err).Msg("failed to remove existing unix socket")
}
listener, err := net.Listen("unix", bindAddr)
if err != nil {
logger.Fatal().Err(err).Msg("failed to serve unix socket")
}
// TODO: safe default for now (rwxr-xr-x), could be extracted as env variable if needed
err = os.Chmod(bindAddr, 0755)
if err != nil {
logger.Fatal().Err(err).Msg("failed to set unix socket permissions")
}
return listener
}

type HandlerError struct {
StatusCode int
Err error
Expand Down

0 comments on commit 62d3798

Please sign in to comment.