Skip to content

Commit

Permalink
refactor(agent,pkg): rename server package to SSH
Browse files Browse the repository at this point in the history
  • Loading branch information
henrybarreto committed Sep 5, 2024
1 parent 07d9ed4 commit 54fb229
Show file tree
Hide file tree
Showing 30 changed files with 225 additions and 302 deletions.
11 changes: 6 additions & 5 deletions agent/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ import (

"github.com/Masterminds/semver"
"github.com/shellhub-io/shellhub/pkg/agent"
"github.com/shellhub-io/shellhub/pkg/agent/connector"
"github.com/shellhub-io/shellhub/pkg/agent/pkg/selfupdater"
"github.com/shellhub-io/shellhub/pkg/agent/server/modes/host/command"
"github.com/shellhub-io/shellhub/pkg/agent/ssh"
"github.com/shellhub-io/shellhub/pkg/agent/ssh/connector"
"github.com/shellhub-io/shellhub/pkg/agent/ssh/modes/host/command"
"github.com/shellhub-io/shellhub/pkg/envs"
"github.com/shellhub-io/shellhub/pkg/loglevel"
log "github.com/sirupsen/logrus"
Expand Down Expand Up @@ -162,14 +163,14 @@ func main() {
}()
}

if err := ag.Listen(ctx); err != nil {
if err := ag.ListenSSH(ctx); err != nil {
log.WithError(err).WithFields(log.Fields{
"version": AgentVersion,
"mode": mode,
"tenant_id": cfg.TenantID,
"server_address": cfg.ServerAddress,
"preferred_hostname": cfg.PreferredHostname,
}).Fatal("Failed to listen for connections")
}).Fatal("Failed to listen for SSH connections")
}

log.WithFields(log.Fields{
Expand Down Expand Up @@ -266,7 +267,7 @@ func main() {
Long: `Starts the SFTP server. This command is used internally by the agent and should not be used directly.
It is initialized by the agent when a new SFTP session is created.`,
Run: func(cmd *cobra.Command, args []string) {
agent.NewSFTPServer(command.SFTPServerMode(args[0]))
ssh.NewSFTPServer(command.SFTPServerMode(args[0]))
},
})

Expand Down
278 changes: 10 additions & 268 deletions pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
// panic(err)
// }
//
// ag.Listen(ctx)
// ag.ListenSSH(ctx)
// }
//
// [ShellHub Agent]: https://github.com/shellhub-io/shellhub/tree/master/agent
Expand All @@ -41,24 +41,15 @@ package agent
import (
"context"
"crypto/rsa"
"io"
"math/rand"
"net"
"net/http"
"net/url"
"os"
"runtime"
"strings"
"sync/atomic"
"time"

"github.com/Masterminds/semver"
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"github.com/shellhub-io/shellhub/pkg/agent/pkg/keygen"
"github.com/shellhub-io/shellhub/pkg/agent/pkg/sysinfo"
"github.com/shellhub-io/shellhub/pkg/agent/pkg/tunnel"
"github.com/shellhub-io/shellhub/pkg/agent/server"
"github.com/shellhub-io/shellhub/pkg/agent/ssh"
"github.com/shellhub-io/shellhub/pkg/api/client"
"github.com/shellhub-io/shellhub/pkg/envs"
"github.com/shellhub-io/shellhub/pkg/models"
Expand Down Expand Up @@ -169,12 +160,9 @@ type Agent struct {
Identity *models.DeviceIdentity
Info *models.DeviceInfo
authData *models.DeviceAuthResponse
cli client.Client
serverInfo *models.Info
server *server.Server
tunnel *tunnel.Tunnel
listening chan bool
closed atomic.Bool
cli client.Client
ssh *ssh.SSH
mode Mode
}

Expand Down Expand Up @@ -264,8 +252,6 @@ func (a *Agent) Initialize() error {
return errors.Wrap(err, "failed to authorize device")
}

a.closed.Store(false)

return nil
}

Expand Down Expand Up @@ -356,263 +342,19 @@ func (a *Agent) authorize() error {
return err
}

func (a *Agent) isClosed() bool {
return a.closed.Load()
}

// Close closes the ShellHub Agent's listening, stoping it from receive new connection requests.
func (a *Agent) Close() error {
a.closed.Store(true)

return a.tunnel.Close()
}

func connHandler(serv *server.Server) func(c echo.Context) error {
return func(c echo.Context) error {
hj, ok := c.Response().Writer.(http.Hijacker)
if !ok {
return c.String(http.StatusInternalServerError, "webserver doesn't support hijacking")
}

conn, _, err := hj.Hijack()
if err != nil {
return c.String(http.StatusInternalServerError, "failed to hijack connection")
}

id := c.Param("id")
httpConn := c.Request().Context().Value("http-conn").(net.Conn)
serv.Sessions.Store(id, httpConn)
serv.HandleConn(httpConn)

conn.Close()

return nil
}
return a.ssh.Close()
}

func httpHandler() func(c echo.Context) error {
return func(c echo.Context) error {
replyError := func(err error, msg string, code int) error {
log.WithError(err).WithFields(log.Fields{
"remote": c.Request().RemoteAddr,
"namespace": c.Request().Header.Get("X-Namespace"),
"path": c.Request().Header.Get("X-Path"),
"version": AgentVersion,
}).Error(msg)

return c.String(code, msg)
}

in, err := net.Dial("tcp", ":80")
if err != nil {
return replyError(err, "failed to connect to HTTP server on device", http.StatusInternalServerError)
}

defer in.Close()

url, err := url.Parse(c.Request().Header.Get("X-Path"))
if err != nil {
return replyError(err, "failed to parse URL", http.StatusInternalServerError)
}

c.Request().URL.Scheme = "http"
c.Request().URL = url

if err := c.Request().Write(in); err != nil {
return replyError(err, "failed to write request to the server on device", http.StatusInternalServerError)
}

out, _, err := c.Response().Hijack()
if err != nil {
return replyError(err, "failed to hijack connection", http.StatusInternalServerError)
}

defer out.Close() // nolint:errcheck
// ListenSSH creates the SSH server and listening for connections.
func (a *Agent) ListenSSH(ctx context.Context) error {
a.ssh = ssh.NewSSH(a.cli, a.authData.Token)

if _, err := io.Copy(out, in); err != nil {
return replyError(err, "failed to copy response from device service to client", http.StatusInternalServerError)
}

return nil
}
}

func closeHandler(a *Agent, serv *server.Server) func(c echo.Context) error {
return func(c echo.Context) error {
id := c.Param("id")
serv.CloseSession(id)

log.WithFields(
log.Fields{
"id": id,
"version": AgentVersion,
"tenant_id": a.authData.Namespace,
"server_address": a.config.ServerAddress,
},
).Info("A tunnel connection was closed")

return nil
}
}

// Listen creates the SSH server and listening for connections.
func (a *Agent) Listen(ctx context.Context) error {
// TODO: Don't create the SSH server from this function, as it seems to be out its own context.
a.mode.Serve(a)

a.tunnel = tunnel.NewBuilder().
WithConnHandler(connHandler(a.server)).
WithCloseHandler(closeHandler(a, a.server)).
WithHTTPHandler(httpHandler()).
Build()

go a.ping(ctx, AgentPingDefaultInterval) //nolint:errcheck

ctx, cancel := context.WithCancel(ctx)
go func() {
for {
if a.isClosed() {
log.WithFields(log.Fields{
"version": AgentVersion,
"tenant_id": a.authData.Namespace,
"server_address": a.config.ServerAddress,
}).Info("Stopped listening for connections")

cancel()

return
}

namespace := a.authData.Namespace
tenantName := a.authData.Name
sshEndpoint := a.serverInfo.Endpoints.SSH

sshid := strings.NewReplacer(
"{namespace}", namespace,
"{tenantName}", tenantName,
"{sshEndpoint}", strings.Split(sshEndpoint, ":")[0],
).Replace("{namespace}.{tenantName}@{sshEndpoint}")

listener, err := a.cli.NewReverseListener(ctx, a.authData.Token, "/ssh/connection")
if err != nil {
log.WithError(err).WithFields(log.Fields{
"version": AgentVersion,
"tenant_id": a.authData.Namespace,
"server_address": a.config.ServerAddress,
"ssh_server": sshEndpoint,
"sshid": sshid,
}).Error("Failed to connect to server through reverse tunnel. Retry in 10 seconds")
time.Sleep(time.Second * 10)

continue
}

log.WithFields(log.Fields{
"namespace": namespace,
"hostname": tenantName,
"server_address": a.config.ServerAddress,
"ssh_server": sshEndpoint,
"sshid": sshid,
}).Info("Server connection established")

a.listening <- true

{
// NOTE: Tunnel'll only realize that it lost its connection to the ShellHub SSH when the next
// "keep-alive" connection fails. As a result, it will take this interval to reconnect to its server.
err := a.tunnel.Listen(listener)

log.WithError(err).WithFields(log.Fields{
"namespace": namespace,
"hostname": tenantName,
"server_address": a.config.ServerAddress,
"ssh_server": sshEndpoint,
"sshid": sshid,
}).Info("Tunnel listener closed")

listener.Close() // nolint:errcheck
}

a.listening <- false
}
}()

<-ctx.Done()

return a.Close()
}

// AgentPingDefaultInterval is the default time interval between ping on agent.
const AgentPingDefaultInterval = 10 * time.Minute

// ping sends an authorization request to the ShellHub server at each interval.
// A random value between 10 and [config.MaxRetryConnectionTimeout] seconds is added to the interval
// each time the ticker is executed.
//
// Ping only sends requests to the server if the agent is listening for connections. If the agent is not
// listening, the ping process will be stopped. When the interval is 0, the default value is 10 minutes.
func (a *Agent) ping(ctx context.Context, interval time.Duration) error {
a.listening = make(chan bool)

if interval == 0 {
interval = AgentPingDefaultInterval
}

<-a.listening // NOTE: wait for the first connection to start to ping the server.
ticker := time.NewTicker(interval)

for {
if a.isClosed() {
return nil
}

select {
case <-ctx.Done():
log.WithFields(log.Fields{
"version": AgentVersion,
"tenant_id": a.authData.Namespace,
"server_address": a.config.ServerAddress,
}).Debug("stopped pinging server due to context cancellation")

return nil
case ok := <-a.listening:
if ok {
log.WithFields(log.Fields{
"version": AgentVersion,
"tenant_id": a.authData.Namespace,
"server_address": a.config.ServerAddress,
"timestamp": time.Now(),
}).Debug("Starting the ping interval to server")

ticker.Reset(interval)
} else {
log.WithFields(log.Fields{
"version": AgentVersion,
"tenant_id": a.authData.Namespace,
"server_address": a.config.ServerAddress,
"timestamp": time.Now(),
}).Debug("Stopped pinging server due listener status")

ticker.Stop()
}
case <-ticker.C:
if err := a.authorize(); err != nil {
a.server.SetDeviceName(a.authData.Name)
}

log.WithFields(log.Fields{
"version": AgentVersion,
"tenant_id": a.authData.Namespace,
"server_address": a.config.ServerAddress,
"name": a.authData.Name,
"hostname": a.config.PreferredHostname,
"identity": a.config.PreferredIdentity,
"timestamp": time.Now(),
}).Info("Ping")

randTimeout := time.Duration(rand.Intn(a.config.MaxRetryConnectionTimeout-10)+10) * time.Second
ticker.Reset(interval + randTimeout)
}
}
return a.ssh.Listen(ctx)
}

// CheckUpdate gets the ShellHub's server version.
Expand Down
Loading

0 comments on commit 54fb229

Please sign in to comment.