// Package main implements a minimal SSH jump host. It lists the hosts the
// current OS user is allowed to reach (as recorded in a local SQLite
// database), lets the user pick one from an interactive menu, and opens an
// interactive SSH session to the chosen host.
package main

import (
	"database/sql"
	"fmt"
	"io"
	"io/ioutil"
	"os"
	"os/signal"
	"os/user"
	"runtime"
	"strconv"
	"strings"
	"syscall"
	"time"

	"github.com/manifoldco/promptui"
	_ "github.com/mattn/go-sqlite3"
	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/terminal"
)

var (
	// DarwinAuthFileFmt is the default private-key path on macOS; %s is the user name.
	DarwinAuthFileFmt = "/Users/%s/.ssh/id_rsa"
	// LinuxAuthFileFmt is the default private-key path on Linux; %s is the user name.
	LinuxAuthFileFmt = "/home/%s/.ssh/id_rsa"
	// AuthFileFmt is the private-key path format for the running OS (set in init).
	AuthFileFmt = LinuxAuthFileFmt
	// DatabaseFile is the SQLite database holding hosts and per-user access rules.
	DatabaseFile = "/usr/local/jumpserver/jumpserver.db"
	// QuitSelect is the menu entry that exits the program; it is always entry 0.
	QuitSelect = "quit"
	// DialTimeout bounds how long establishing an SSH connection may take.
	DialTimeout = 10 * time.Second
)

// SSHTerminal bridges the local terminal to a remote interactive SSH session.
type SSHTerminal struct {
	Session *ssh.Session
	exitMsg string
	stdout  io.Reader
	stdin   io.Writer
	stderr  io.Reader
	// done is closed when interactiveSession returns, which stops the
	// window-resize watcher started by updateTerminalSize.
	done chan struct{}
}

// HostServer is an SSH server the user can connect to.
type HostServer struct {
	Name string
	Host string
	Port int
}

// SSHUser holds the credentials used to log in to a HostServer.
type SSHUser struct {
	Username string
	// Password is offered as a fallback auth method only when non-empty.
	Password string
	// IdentityFile is the private-key path; when empty, AuthFileFmt is used.
	IdentityFile string
}

// init selects the default private-key location for the running OS.
// Unsupported systems get a warning and keep the Linux default.
func init() {
	switch runtime.GOOS {
	case "linux":
		AuthFileFmt = LinuxAuthFileFmt
	case "darwin":
		AuthFileFmt = DarwinAuthFileFmt
	default:
		fmt.Printf("program not support os(%s)\n", runtime.GOOS)
	}
}

// updateTerminalSize starts a goroutine that forwards local terminal resizes
// (SIGWINCH) to the remote PTY. The goroutine stops listening for SIGWINCH
// and exits once t.done is closed; if t.done is nil it runs until the process
// exits.
func (t *SSHTerminal) updateTerminalSize() {
	done := t.done
	go func() {
		sigwinchCh := make(chan os.Signal, 1)
		signal.Notify(sigwinchCh, syscall.SIGWINCH)
		defer signal.Stop(sigwinchCh)

		fd := int(os.Stdin.Fd())
		termWidth, termHeight, err := terminal.GetSize(fd)
		if err != nil {
			fmt.Println(err)
		}
		for {
			select {
			case <-done:
				return
			case <-sigwinchCh:
				currTermWidth, currTermHeight, err := terminal.GetSize(fd)
				if err != nil {
					// Cannot read the new size; keep the last known one.
					continue
				}
				// Ignore signals that did not actually change the size.
				if currTermHeight == termHeight && currTermWidth == termWidth {
					continue
				}
				if err := t.Session.WindowChange(currTermHeight, currTermWidth); err != nil {
					fmt.Printf("Unable to send window-change reqest: %s.", err)
					continue
				}
				termWidth, termHeight = currTermWidth, currTermHeight
			}
		}
	}()
}

// interactiveSession puts the local terminal into raw mode, requests a remote
// PTY of the same size, wires stdin/stdout/stderr to the session, starts a
// shell, and blocks until the remote side finishes. The local terminal state
// is restored and an exit message is printed on return.
func (t *SSHTerminal) interactiveSession() error {
	defer func() {
		if t.exitMsg == "" {
			fmt.Fprintln(os.Stdout, "the connection was closed on the remote side on ", time.Now().Format(time.RFC822))
		} else {
			fmt.Fprintln(os.Stdout, t.exitMsg)
		}
	}()

	fd := int(os.Stdin.Fd())
	state, err := terminal.MakeRaw(fd)
	if err != nil {
		return err
	}
	defer terminal.Restore(fd, state)

	termWidth, termHeight, err := terminal.GetSize(fd)
	if err != nil {
		return err
	}
	termType := os.Getenv("TERM")
	if termType == "" {
		termType = "xterm-256color"
	}
	err = t.Session.RequestPty(termType, termHeight, termWidth, ssh.TerminalModes{
		ssh.ECHO:          1,
		ssh.TTY_OP_ISPEED: 14400,
		ssh.TTY_OP_OSPEED: 14400,
	})
	if err != nil {
		return err
	}

	// Closing done on return stops the resize watcher for this session.
	t.done = make(chan struct{})
	defer close(t.done)
	t.updateTerminalSize()

	if t.stdin, err = t.Session.StdinPipe(); err != nil {
		return err
	}
	if t.stdout, err = t.Session.StdoutPipe(); err != nil {
		return err
	}
	if t.stderr, err = t.Session.StderrPipe(); err != nil {
		return err
	}
	go io.Copy(os.Stderr, t.stderr)
	go io.Copy(os.Stdout, t.stdout)

	// NOTE(review): this reader is not stopped when the session ends. It stays
	// blocked in os.Stdin.Read, so the first input read after the session
	// closes may be consumed here (and fail writing to the closed pipe). It
	// also writes t.exitMsg without synchronization while the deferred
	// printer above reads it.
	go func() {
		buf := make([]byte, 128)
		for {
			n, err := os.Stdin.Read(buf)
			if err != nil {
				fmt.Println(err)
				return
			}
			if n > 0 {
				if _, err = t.stdin.Write(buf[:n]); err != nil {
					fmt.Println(err)
					t.exitMsg = err.Error()
					return
				}
			}
		}
	}()

	if err := t.Session.Shell(); err != nil {
		return err
	}
	return t.Session.Wait()
}

// hostAddr joins host and port into a dial address, bracketing IPv6 literals.
func hostAddr(host string, port int) string {
	if strings.Contains(host, ":") {
		return "[" + host + "]:" + strconv.Itoa(port)
	}
	return host + ":" + strconv.Itoa(port)
}

// NewTerminal connects to server as sshUser and runs an interactive session
// until it ends. Public-key auth is tried first, using sshUser.IdentityFile
// (or the AuthFileFmt default when empty); password auth is added only when
// sshUser.Password is non-empty. Any connection or session error is returned.
func NewTerminal(server *HostServer, sshUser *SSHUser) error {
	keyPath := sshUser.IdentityFile
	if keyPath == "" {
		keyPath = fmt.Sprintf(AuthFileFmt, sshUser.Username)
	}
	key, err := ioutil.ReadFile(keyPath)
	if err != nil {
		return err
	}
	signer, err := ssh.ParsePrivateKey(key)
	if err != nil {
		return err
	}

	auths := []ssh.AuthMethod{ssh.PublicKeys(signer)}
	if sshUser.Password != "" {
		auths = append(auths, ssh.Password(sshUser.Password))
	}

	// SECURITY(review): host keys are not verified, so the connection is open
	// to man-in-the-middle attacks. Consider a known_hosts based callback.
	sshConfig := &ssh.ClientConfig{
		User:            sshUser.Username,
		Auth:            auths,
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         DialTimeout,
	}

	addr := hostAddr(server.Host, server.Port)
	client, err := ssh.Dial("tcp", addr, sshConfig)
	if err != nil {
		// Returning here matters: client is nil on error.
		return fmt.Errorf("dial %s: %w", addr, err)
	}
	defer client.Close()

	session, err := client.NewSession()
	if err != nil {
		return err
	}
	defer session.Close()

	s := SSHTerminal{Session: session}
	return s.interactiveSession()
}

// handlerSignal blocks until SIGQUIT, SIGTERM or SIGINT (or any other
// delivered signal except SIGHUP) arrives; SIGHUP is caught and ignored.
// NOTE(review): not called anywhere in this file.
func handlerSignal() {
	c := make(chan os.Signal, 1)
	signal.Notify(c, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT)
	defer signal.Stop(c)
	for s := range c {
		if s != syscall.SIGHUP {
			return
		}
	}
}

// loadHosts returns the non-deleted hosts that username may access according
// to the hosts and hostuser tables in DatabaseFile. Rows that fail to scan are
// reported and skipped.
func loadHosts(username string) ([]HostServer, error) {
	db, err := sql.Open("sqlite3", DatabaseFile)
	if err != nil {
		return nil, fmt.Errorf("open database %q: %w", DatabaseFile, err)
	}
	defer db.Close()

	// Parameterized to keep the user name out of the SQL text.
	rows, err := db.Query("SELECT name, ip, port FROM hosts where isdelete=0 and name in(select hostname from hostuser where username=?);", username)
	if err != nil {
		return nil, fmt.Errorf("query hosts for %q: %w", username, err)
	}
	defer rows.Close()

	var hosts []HostServer
	for rows.Next() {
		var h HostServer
		if err := rows.Scan(&h.Name, &h.Host, &h.Port); err != nil {
			fmt.Printf("rows.Scan failed! err=%v\n", err)
			continue
		}
		hosts = append(hosts, h)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("read hosts: %w", err)
	}
	return hosts, nil
}

// main shows the hosts the current OS user may access and opens an SSH
// session to each one the user selects, until QuitSelect is chosen or the
// prompt fails (e.g. on Ctrl+C).
func main() {
	// Hosts are filtered by the name of the OS user running the program.
	userInfo, err := user.Current()
	if err != nil {
		fmt.Println("user.Current failed!")
		return
	}

	hosts, err := loadHosts(userInfo.Username)
	if err != nil {
		fmt.Println(err)
		return
	}

	// Menu entry 0 is QuitSelect; entry i (i >= 1) corresponds to hosts[i-1].
	menuLabels := make([]string, 0, len(hosts)+1)
	menuLabels = append(menuLabels, QuitSelect)
	for _, h := range hosts {
		menuLabels = append(menuLabels, fmt.Sprintf("%s:%s:%d", h.Name, h.Host, h.Port))
	}

	prompt := promptui.Select{
		Label: "Select Host (quit for ctrl+c or select quit)",
		Items: menuLabels,
		Size:  len(menuLabels),
	}
	for {
		index, selectLabel, err := prompt.Run()
		if err != nil {
			fmt.Printf("Prompt failed %v\n", err)
			return
		}
		fmt.Printf("You choose %q\n", selectLabel)
		if index == 0 {
			return
		}
		// Look the host up by menu position rather than re-parsing the label,
		// which would break for IPv6 addresses or names containing ':'.
		if index < 0 || index > len(hosts) {
			fmt.Printf("invalid ssh host: %s\n", selectLabel)
			continue
		}
		server := hosts[index-1]
		sshUser := &SSHUser{Username: userInfo.Username, IdentityFile: fmt.Sprintf(AuthFileFmt, userInfo.Username)}

		if err := NewTerminal(&server, sshUser); err != nil {
			fmt.Printf("NewTerminal err=%v\n", err)
		}
	}
}