|
|
package main
import ( "database/sql" "fmt" "io" "io/ioutil" "os" "os/signal" "os/user" "runtime" "strconv" "strings" "syscall" "time"
"github.com/manifoldco/promptui" _ "github.com/mattn/go-sqlite3" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/terminal" )
var (
	// DarwinAuthFileFmt is the private-key path template used on macOS;
	// the %s verb is filled with the user name.
	DarwinAuthFileFmt = "/Users/%s/.ssh/id_rsa"

	// LinuxAuthFileFmt is the private-key path template used on Linux;
	// the %s verb is filled with the user name.
	LinuxAuthFileFmt = "/home/%s/.ssh/id_rsa"

	// AuthFileFmt is the private-key path template in effect. It starts out
	// as the Linux template; init may switch it based on runtime.GOOS.
	AuthFileFmt = LinuxAuthFileFmt

	// DatabaseFile is the SQLite database queried for the hosts a user may
	// access.
	DatabaseFile = "/usr/local/jumpserver/jumpserver.db"

	// QuitSelect is the menu label that exits the host picker.
	QuitSelect = "quit"
)
type SSHTerminal struct { Session *ssh.Session exitMsg string stdout io.Reader stdin io.Writer stderr io.Reader }
// HostServer is an SSH destination offered in the host-selection menu.
type HostServer struct {
	Name string // display name of the host
	Host string // address to dial
	Port int    // TCP port of the SSH server
}
// SSHUser carries the login identity used when connecting to a HostServer.
type SSHUser struct {
	Username     string // remote login name
	Password     string // password offered for password authentication; may be empty
	IdentityFile string // path of the private key file
}
// init selects the private-key path template for the running OS. On an
// unsupported OS it prints a warning and leaves AuthFileFmt at its Linux
// default.
func init() {
	switch runtime.GOOS {
	case "linux":
		AuthFileFmt = LinuxAuthFileFmt
	case "darwin":
		AuthFileFmt = DarwinAuthFileFmt
	default:
		fmt.Printf("program not support os(%s)\n", runtime.GOOS)
	}
}
func (t *SSHTerminal) updateTerminalSize() { go func() { // 窗口大小改变
sigwinchCh := make(chan os.Signal, 1) signal.Notify(sigwinchCh, syscall.SIGWINCH)
fd := int(os.Stdin.Fd()) termWidth, termHeight, err := terminal.GetSize(fd) if err != nil { fmt.Println(err) }
for { select { case sigwinch := <-sigwinchCh: if sigwinch == nil { return } currTermWidth, currTermHeight, err := terminal.GetSize(fd)
// 窗口没有发生改变
if currTermHeight == termHeight && currTermWidth == termWidth { continue }
t.Session.WindowChange(currTermHeight, currTermWidth) if err != nil { fmt.Printf("Unable to send window-change reqest: %s.", err) continue }
termWidth, termHeight = currTermWidth, currTermHeight
} } }()
}
func (t *SSHTerminal) interactiveSession() error { defer func() { if t.exitMsg == "" { fmt.Fprintln(os.Stdout, "the connection was closed on the remote side on ", time.Now().Format(time.RFC822)) } else { fmt.Fprintln(os.Stdout, t.exitMsg) } }()
fd := int(os.Stdin.Fd()) state, err := terminal.MakeRaw(fd) if err != nil { return err } defer terminal.Restore(fd, state)
termWidth, termHeight, err := terminal.GetSize(fd) if err != nil { return err }
termType := os.Getenv("TERM") if termType == "" { termType = "xterm-256color" }
err = t.Session.RequestPty(termType, termHeight, termWidth, ssh.TerminalModes{ ssh.ECHO: 1, ssh.TTY_OP_ISPEED: 14400, ssh.TTY_OP_OSPEED: 14400, }) if err != nil { return err }
t.updateTerminalSize()
t.stdin, err = t.Session.StdinPipe() if err != nil { return err } t.stdout, err = t.Session.StdoutPipe() if err != nil { return err } t.stderr, err = t.Session.StderrPipe()
go io.Copy(os.Stderr, t.stderr) go io.Copy(os.Stdout, t.stdout) go func() { buf := make([]byte, 128) for { n, err := os.Stdin.Read(buf) if err != nil { fmt.Println(err) return } if n > 0 { _, err = t.stdin.Write(buf[:n]) if err != nil { fmt.Println(err) t.exitMsg = err.Error() return } } } }()
err = t.Session.Shell() if err != nil { return err } err = t.Session.Wait() if err != nil { return err } return nil }
func NewTerminal(server *HostServer, sshUser *SSHUser) (err error) { key, err := ioutil.ReadFile(fmt.Sprintf(AuthFileFmt, sshUser.Username)) if err != nil { return } signer, err := ssh.ParsePrivateKey(key) if err != nil { return }
sshConfig := &ssh.ClientConfig{ User: sshUser.Username, Auth: []ssh.AuthMethod{ ssh.Password(sshUser.Password), ssh.PublicKeys(signer), }, HostKeyCallback: ssh.InsecureIgnoreHostKey(), }
client, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", server.Host, server.Port), sshConfig) if err != nil { fmt.Println(err) } defer client.Close()
session, err := client.NewSession() if err != nil { return err } defer session.Close()
s := SSHTerminal{ Session: session, } return s.interactiveSession() }
// handlerSignal blocks until the process receives SIGINT, SIGTERM or SIGQUIT.
// SIGHUP is received and ignored.
//
// Before returning, it unsubscribes its channel with signal.Stop, so these
// signals are no longer diverted to a channel nobody reads.
func handlerSignal() {
	c := make(chan os.Signal, 1)
	signal.Notify(c, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT)
	defer signal.Stop(c)

	for s := range c {
		if s == syscall.SIGHUP {
			continue // ignore hangups and keep waiting
		}
		return
	}
}
func main() { // 获取当前用户
userInfo, err := user.Current() if err != nil { fmt.Println("user.Current failed!") return }
// 连接数据库
db, err := sql.Open("sqlite3", DatabaseFile) if err != nil { fmt.Printf("open database failed! err=%v", err) return }
// 查询当前用户有访问权限的主机
rows, err := db.Query(fmt.Sprintf("SELECT name, ip, port FROM hosts where isdelete=0 and name in(select hostname from hostuser where username='%s');", userInfo.Username)) if err != nil { fmt.Printf("db.Query failed! err=%v", err) return }
// ui展示列表
menuLabels := make([]string, 0) menuLabels = append(menuLabels, QuitSelect) // 退出
for rows.Next() { var ( name string ip string port int ) err = rows.Scan(&name, &ip, &port) if err != nil { fmt.Printf("rows.Scan failed! err=%v", err) continue }
menuLabels = append(menuLabels, fmt.Sprintf("%s:%s:%d", name, ip, port)) } rows.Close() db.Close()
// 选项列表
prompt := promptui.Select{ Label: "Select Host (quit for ctrl+c or select quit)", Items: menuLabels, Size: len(menuLabels), }
for { _, selectLabel, err := prompt.Run()
if err != nil { fmt.Printf("Prompt failed %v\n", err) return }
fmt.Printf("You choose %q\n", selectLabel) if selectLabel == QuitSelect { return }
// 获取选中信息
hostItems := strings.Split(selectLabel, ":") if len(hostItems) != 3 { fmt.Printf("invalid ssh host: %s", selectLabel) continue } port, err := strconv.ParseInt(hostItems[2], 10, 32) if err != nil { fmt.Printf("invalid ssh port: %s", selectLabel[2]) continue } server := &HostServer{Name: hostItems[0], Host: hostItems[1], Port: int(port)} user := &SSHUser{Username: userInfo.Username, IdentityFile: fmt.Sprintf(AuthFileFmt, userInfo.Username)}
// 新建一个终端
err = NewTerminal(server, user) if err != nil { fmt.Printf("NewTerminal err=%v\n", err) } } }
|