Skip to content

Commit

Permalink
feat(ssh): add logs on ssh connection dialer
Browse files Browse the repository at this point in the history
  • Loading branch information
henrybarreto committed Sep 16, 2024
1 parent 3a2b47e commit e049ec2
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 18 deletions.
1 change: 1 addition & 0 deletions gateway/nginx/conf.d/shellhub.conf
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ server {
{{ end -}}
proxy_set_header X-Device-UID $device_uid;
proxy_set_header X-Tenant-ID $tenant_id;
proxy_set_header X-Request-ID $request_id;
proxy_http_version 1.1;
proxy_cache_bypass $http_upgrade;
proxy_redirect off;
Expand Down
21 changes: 17 additions & 4 deletions pkg/connman/connman.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ import (
"context"
"errors"
"net"
"os"
"strings"

"github.com/shellhub-io/shellhub/pkg/revdial"
"github.com/shellhub-io/shellhub/pkg/wsconnadapter"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
)

var ErrNoConnection = errors.New("no connection")
Expand All @@ -27,12 +29,23 @@ func New() *ConnectionManager {
}

func (m *ConnectionManager) Set(key string, conn *wsconnadapter.Adapter, connPath string) {
dialer := revdial.NewDialer(conn, connPath)
parts := strings.Split(key, ":")
logger := (&log.Logger{
Out: os.Stderr,
Formatter: log.StandardLogger().Formatter,
Hooks: log.StandardLogger().Hooks,
Level: log.StandardLogger().Level,
}).WithFields(log.Fields{
"tenant": parts[0],
"device": parts[1],
})

dialer := revdial.NewDialer(logger, conn, connPath)

m.dialers.Store(key, dialer)

if size := m.dialers.Size(key); size > 1 {
logrus.WithFields(logrus.Fields{
log.WithFields(log.Fields{
"key": key,
"size": size,
}).Warning("Multiple connections stored for the same identifier.")
Expand Down Expand Up @@ -67,7 +80,7 @@ func (m *ConnectionManager) Dial(ctx context.Context, key string) (net.Conn, err
}

if size := m.dialers.Size(key); size > 1 {
logrus.WithFields(logrus.Fields{
log.WithFields(log.Fields{
"key": key,
"size": size,
}).Warning("Multiple connections found for the same identifier during reverse tunnel dialing.")
Expand Down
13 changes: 11 additions & 2 deletions pkg/httptunnel/httptunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ import (
"io"
"net"
"net/http"
"strings"

"github.com/gorilla/websocket"
"github.com/labstack/echo/v4"
"github.com/shellhub-io/shellhub/pkg/connman"
"github.com/shellhub-io/shellhub/pkg/revdial"
"github.com/shellhub-io/shellhub/pkg/wsconnadapter"
log "github.com/sirupsen/logrus"
)

var upgrader = websocket.Upgrader{
Expand Down Expand Up @@ -75,14 +77,21 @@ func (t *Tunnel) Router() http.Handler {
return c.String(http.StatusInternalServerError, err.Error())
}

id, err := t.ConnectionHandler(c.Request())
key, err := t.ConnectionHandler(c.Request())
if err != nil {
conn.Close()

return c.String(http.StatusBadRequest, err.Error())
}

t.connman.Set(id, wsconnadapter.New(conn), t.DialerPath)
parts := strings.Split(key, ":")
log.WithFields(log.Fields{
"request-id": c.Request().Header.Get("X-Request-ID"),
"tenant": parts[0],
"device": parts[1],
}).Debug("new ssh connection")

t.connman.Set(key, wsconnadapter.New(conn), t.DialerPath)

return nil
})
Expand Down
37 changes: 27 additions & 10 deletions pkg/revdial/revdial.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"encoding/json"
"errors"
"fmt"
"log"
"net"
"net/http"
"strings"
Expand All @@ -34,11 +33,13 @@ import (
"github.com/gorilla/websocket"
"github.com/shellhub-io/shellhub/pkg/clock"
"github.com/shellhub-io/shellhub/pkg/wsconnadapter"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
)

var ErrDialerClosed = errors.New("revdial.Dialer closed")
var ErrDialerTimedout = errors.New("revdial.Dialer timedout")
var (
ErrDialerClosed = errors.New("revdial.Dialer closed")
ErrDialerTimedout = errors.New("revdial.Dialer timedout")
)

// dialerUniqParam is the parameter name of the GET URL form value
// containing the Dialer's random unique ID.
Expand All @@ -59,19 +60,18 @@ type Dialer struct {
connReady chan bool
donec chan struct{}
closeOnce sync.Once
logger *log.Entry
}

var (
dialers = sync.Map{}
)
var dialers = sync.Map{}

// NewDialer returns the side of the connection which will initiate
// new connections. This will typically be the side which did the HTTP
// Hijack. The connection is (typically) the hijacked HTTP client
// connection. The connPath is the HTTP path and optional query (but
// without scheme or host) on the dialer where the ConnHandler is
// mounted.
func NewDialer(c net.Conn, connPath string) *Dialer {
func NewDialer(logger *log.Entry, c net.Conn, connPath string) *Dialer {
d := &Dialer{
path: connPath,
uniqID: newUniqID(),
Expand All @@ -80,6 +80,7 @@ func NewDialer(c net.Conn, connPath string) *Dialer {
connReady: make(chan bool),
incomingConn: make(chan net.Conn),
pickupFailed: make(chan error),
logger: logger,
}

join := "?"
Expand Down Expand Up @@ -121,6 +122,8 @@ func (d *Dialer) Close() error {
}

func (d *Dialer) close() {
d.logger.Debug("dialer connection closed")

d.unregister()
d.conn.Close()
d.donec <- struct{}{}
Expand Down Expand Up @@ -165,21 +168,24 @@ func (d *Dialer) serve() error {

go func() {
defer d.Close()
defer d.logger.Debug("dialer serve done due reader error")

br := bufio.NewReader(d.conn)
for {
line, err := br.ReadSlice('\n')
if err != nil {
d.logger.WithError(err).Debug("failed to read the agent's command")

unexpectedError := websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure)
if !errors.Is(err, net.ErrClosed) && unexpectedError {
logrus.WithError(err).Error("revdial.Dialer failed to read")
d.logger.WithError(err).Error("revdial.Dialer failed to read")
}

return
}
var msg controlMsg
if err := json.Unmarshal(line, &msg); err != nil {
log.Printf("revdial.Dialer read invalid JSON: %q: %v", line, err)
d.logger.Printf("revdial.Dialer read invalid JSON: %q: %v", line, err)

return
}
Expand All @@ -190,16 +196,21 @@ func (d *Dialer) serve() error {
select {
case d.pickupFailed <- err:
case <-d.donec:
d.logger.WithError(err).Debug("failed to pick-up connection")

return
}
case "keep-alive":
default:
// Ignore unknown messages
log.WithField("message", msg.Command).Debug("unknown message received")
}
}
}()
for {
if err := d.sendMessage(controlMsg{Command: "keep-alive"}); err != nil {
d.logger.WithError(err).Debug("failed to send keep-alive message to device")

return err
}

Expand All @@ -213,6 +224,8 @@ func (d *Dialer) serve() error {
Command: "conn-ready",
ConnPath: d.pickupPath,
}); err != nil {
d.logger.WithError(err).Debug("failed to send conn-ready message to device")

return err
}
case <-d.donec:
Expand All @@ -225,13 +238,17 @@ func (d *Dialer) serve() error {

func (d *Dialer) sendMessage(m controlMsg) error {
if err := d.conn.SetWriteDeadline(clock.Now().Add(10 * time.Second)); err != nil {
d.logger.WithError(err).Debug("failed to set the write dead line to device")

return err
}

j, _ := json.Marshal(m)
j = append(j, '\n')

if _, err := d.conn.Write(j); err != nil {
d.logger.WithError(err).Debug("failed to write on the connection")

return err
}

Expand Down
6 changes: 4 additions & 2 deletions pkg/wsconnadapter/wsconnadapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"time"

"github.com/gorilla/websocket"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
)

// an adapter for representing WebSocket connection as a net.Conn
Expand Down Expand Up @@ -71,9 +71,11 @@ func (a *Adapter) Ping() chan bool {
select {
case <-ticker.C:
if err := a.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second)); err != nil {
logrus.WithError(err).Error("Failed to write ping message")
log.WithError(err).Error("Failed to write ping message")
}
case <-a.stopPingCh:
log.Debug("Stop ping message received")

return
}
}
Expand Down

0 comments on commit e049ec2

Please sign in to comment.