You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
309 lines
9.2 KiB
309 lines
9.2 KiB
#!/usr/bin/env python3
|
|
# -*- coding: UTF-8 -*-
|
|
|
|
import paramiko
|
|
import os
|
|
import sys
|
|
import curses
|
|
import termios
|
|
import tty
|
|
import socket
|
|
import signal
|
|
import getpass
|
|
import platform
|
|
import sqlite3
|
|
|
|
# Curses key codes recognised by the Picker menu.
KEYS_ENTER = (curses.KEY_ENTER, ord('\n'), ord('\r'))
KEYS_UP = (curses.KEY_UP, ord('k'))
KEYS_DOWN = (curses.KEY_DOWN, ord('j'))
KEYS_SELECT = (curses.KEY_RIGHT, ord(' '))

# Default path of the jump-host SQLite database; init() replaces it with the
# JUMPDB environment variable when that is set.
gSqlite3File = "/usr/local/jumpserver/jumpserver.db"

# ssh_private_path: SSH private key path template; "%s" is filled with the
# login name before connecting.
# NOTE(review): assumes home directories live under /Users (macOS) or /home
# (Linux); e.g. root's /root is not covered — confirm for the deployment.
system_type = platform.system()
if system_type == "Darwin":
    ssh_private_path = "/Users/%s/.ssh/id_rsa"
elif system_type == "Linux":
    ssh_private_path = "/home/%s/.ssh/id_rsa"
else:
    # sys.exit() rather than exit(): the exit() builtin is added by the site
    # module for interactive use and is not guaranteed to exist in scripts.
    print("Don't support your system(%s)" % system_type, file=sys.stderr)
    sys.exit(1)
|
|
|
|
def init():
    """Select the jump database from the JUMPDB environment variable.

    Sets the module-level ``gSqlite3File`` to ``$JUMPDB`` when that variable
    is set to a non-empty value; otherwise the default path is kept and a
    notice is printed to stderr. A missing JUMPDB is not fatal.
    """
    global gSqlite3File

    jump_db = os.environ.get("JUMPDB")
    # Treat an empty JUMPDB like an unset one: sqlite3.connect("") would open
    # a private temporary database with no tables instead of the real one.
    if jump_db:
        gSqlite3File = jump_db
    else:
        print("ERROR: Hadn't set environ var 'JUMPDB' for jump database! mod 755。Default use %s" % gSqlite3File,
              file=sys.stderr)
|
|
|
|
class Picker(object):
    """Curses menu for choosing one option (or several, in multiselect mode).

    Controls: up-arrow/``k`` and down-arrow/``j`` move the cursor, wrapping
    around at either end; Enter confirms; right-arrow/space toggles the
    option under the cursor when multiselect is on; ``q`` exits the process.
    """

    def __init__(self, options, title=None, indicator='*', default_index=0, multiselect=False, multi_select=False,
                 min_selection_count=0, options_map_func=None):
        """Validate and store the menu configuration.

        :param options: non-empty sequence of options to display.
        :param title: optional text drawn above the options (may contain newlines).
        :param indicator: marker drawn in front of the option under the cursor.
        :param default_index: option highlighted initially; must be < len(options).
        :param multiselect: allow marking several options (``multi_select`` is an alias).
        :param min_selection_count: in multiselect mode, Enter is ignored until
            at least this many options are marked.
        :param options_map_func: optional callable turning an option into its display text.
        :raises ValueError: for an empty option list or inconsistent arguments.
        """
        if len(options) == 0:
            raise ValueError('options should not be an empty list')

        self.options = options
        self.title = title
        self.indicator = indicator
        self.multiselect = multiselect or multi_select
        self.min_selection_count = min_selection_count
        self.options_map_func = options_map_func
        self.all_selected = []

        if default_index >= len(options):
            raise ValueError('default_index should be less than the length of options')

        # Check the effective flag so the ``multi_select`` alias is validated too.
        if self.multiselect and min_selection_count > len(options):
            raise ValueError(
                'min_selection_count is bigger than the available options, you will not be able to make any selection')

        if options_map_func is not None and not callable(options_map_func):
            raise ValueError('options_map_func must be a callable function')

        self.index = default_index
        self.custom_handlers = {}

    def register_custom_handler(self, key, func):
        """Call ``func(self)`` when ``key`` is pressed.

        A truthy return value ends ``run_loop`` and is returned from it.
        Built-in keys are checked first and cannot be overridden.
        """
        self.custom_handlers[key] = func

    def move_up(self):
        """Move the cursor up one option, wrapping to the last option."""
        self.index -= 1
        if self.index < 0:
            self.index = len(self.options) - 1

    def move_down(self):
        """Move the cursor down one option, wrapping to the first option."""
        self.index += 1
        if self.index >= len(self.options):
            self.index = 0

    def mark_index(self):
        """Toggle the option under the cursor in the selection (multiselect only)."""
        if self.multiselect:
            if self.index in self.all_selected:
                self.all_selected.remove(self.index)
            else:
                self.all_selected.append(self.index)

    def get_selected(self):
        """Return the current selection as ``(option, index)``.

        In multiselect mode, return a list of such tuples in the order the
        options were marked.
        """
        if self.multiselect:
            return [(self.options[selected], selected) for selected in self.all_selected]
        return self.options[self.index], self.index

    def get_title_lines(self):
        """Return the title split into lines plus one blank separator, or [] without a title."""
        if self.title:
            return self.title.split('\n') + ['']
        return []

    def get_option_lines(self):
        """Return one display line per option.

        A line is ``"<prefix> <text>"``: the prefix is the indicator on the
        cursor row and spaces elsewhere. Options marked in multiselect mode
        are returned as ``(line, curses attribute)`` tuples.
        """
        lines = []
        blank_prefix = len(self.indicator) * ' '
        for index, option in enumerate(self.options):
            # pass the option through the options map if one was passed in
            if self.options_map_func:
                option = self.options_map_func(option)

            prefix = self.indicator if index == self.index else blank_prefix
            text = '{0} {1}'.format(prefix, option)
            if self.multiselect and index in self.all_selected:
                lines.append((text, curses.color_pair(1)))
            else:
                lines.append(text)
        return lines

    def get_lines(self):
        """Return ``(lines, current_line)``: title + option lines, and the 1-based cursor line."""
        title_lines = self.get_title_lines()
        option_lines = self.get_option_lines()
        lines = title_lines + option_lines
        current_line = self.index + len(title_lines) + 1
        return lines, current_line

    def draw(self):
        """Draw the menu on ``self.screen``, scrolling so the cursor row stays visible."""
        self.screen.clear()

        x, y = 1, 1  # start point
        max_y, max_x = self.screen.getmaxyx()
        max_rows = max_y - y  # the max rows we can draw

        lines, current_line = self.get_lines()

        # Window is lines[scroll_top:scroll_top + max_rows]; current_line is 1-based.
        scroll_top = getattr(self, 'scroll_top', 0)
        if current_line <= scroll_top:
            # Cursor moved above the window. Snap back to the very top (which
            # also shows the title) only if the first page still contains the
            # cursor; otherwise make the cursor row the first visible row.
            scroll_top = 0 if current_line <= max_rows else current_line - 1
        elif current_line - scroll_top > max_rows:
            scroll_top = current_line - max_rows
        self.scroll_top = scroll_top

        for line in lines[scroll_top:scroll_top + max_rows]:
            if isinstance(line, tuple):
                self.screen.addnstr(y, x, line[0], max_x - 2, line[1])
            else:
                self.screen.addnstr(y, x, line, max_x - 2)
            y += 1

        self.screen.refresh()

    def run_loop(self):
        """Redraw and dispatch key presses until a selection is made.

        Returns ``get_selected()`` on Enter (once ``min_selection_count`` is
        satisfied in multiselect mode) or a custom handler's truthy result.
        ``q`` exits the process.
        """
        while True:
            self.draw()
            c = self.screen.getch()
            if c in KEYS_UP:
                self.move_up()
            elif c in KEYS_DOWN:
                self.move_down()
            elif c in KEYS_ENTER:
                if self.multiselect and len(self.all_selected) < self.min_selection_count:
                    continue
                return self.get_selected()
            elif c in KEYS_SELECT and self.multiselect:
                self.mark_index()
            elif c in self.custom_handlers:
                ret = self.custom_handlers[c](self)
                if ret:
                    return ret
            elif c == ord('q'):
                # sys.exit rather than the interactive-only exit() builtin;
                # curses.wrapper restores the terminal as the exception unwinds.
                sys.exit(0)

    def config_curses(self):
        """Use default colours, hide the cursor and set up the selection colour pair."""
        try:
            # use the default colors of the terminal
            curses.use_default_colors()
            # hide the cursor
            curses.curs_set(0)
            # add some color for multi_select
            # @todo make colors configurable
            curses.init_pair(1, curses.COLOR_GREEN, curses.COLOR_WHITE)
        except curses.error:
            # Curses failed to initialize color support, eg. when TERM=vt100
            curses.initscr()

    def _start(self, screen):
        """curses.wrapper callback: keep the screen, configure curses, run the loop."""
        self.screen = screen
        self.config_curses()
        return self.run_loop()

    def start(self):
        """Show the menu and return the selection produced by ``run_loop``."""
        return curses.wrapper(self._start)
|
|
|
|
|
|
def updateWindowHandler(signum, frame=None):
    """SIGWINCH handler: return the current size of the local terminal.

    Python invokes signal handlers as ``handler(signum, frame)``, so ``frame``
    must be accepted; it defaults to None so one-argument calls still work.

    :returns: an ``os.terminal_size`` for SIGWINCH; None for any other signal
        or when stdout is not attached to a terminal.
    """
    if signum != signal.SIGWINCH:
        return None
    try:
        return os.get_terminal_size()
    except OSError:
        # stdout is not a terminal; a signal handler must not raise.
        return None
|
|
|
|
|
|
def posix_shell(chan):
    """Relay the local terminal to an interactive channel until either side reaches EOF.

    The local TTY is put into raw mode so keystrokes (including control
    characters) reach the remote shell unmodified. Local window resizes
    (SIGWINCH) are forwarded with ``chan.resize_pty``. The terminal
    attributes and the previous SIGWINCH handler are restored on exit.

    :param chan: interactive channel, e.g. from ``SSHClient.invoke_shell``.
    """
    import select

    stdin_fd = sys.stdin.fileno()
    oldtty = termios.tcgetattr(stdin_fd)

    # Self-pipe: the SIGWINCH handler only writes a byte here and the main
    # loop performs the resize, so no channel/transport locks are taken from
    # signal context.
    wake_r, wake_w = os.pipe()
    old_handler = None

    def on_winch(signum, frame):
        try:
            os.write(wake_w, b"\0")
        except BlockingIOError:
            pass  # pipe full: a resize is already pending

    try:
        os.set_blocking(wake_w, False)
        old_handler = signal.signal(signal.SIGWINCH, on_winch)
        tty.setraw(stdin_fd, termios.TCIOFLUSH)
        chan.settimeout(0.0)
        while True:
            r, _, _ = select.select([chan, stdin_fd, wake_r], [], [])
            if wake_r in r:
                os.read(wake_r, 1024)
                cols, rows = os.get_terminal_size(stdin_fd)
                chan.resize_pty(width=cols, height=rows)
            if chan in r:
                try:
                    data = chan.recv(1024)
                except socket.timeout:
                    data = None
                if data is not None:
                    if not data:
                        break  # remote side closed the channel
                    # Forward raw bytes: decoding each chunk on its own breaks
                    # on UTF-8 characters split across recv() calls.
                    sys.stdout.buffer.write(data)
                    sys.stdout.buffer.flush()
            if stdin_fd in r:
                # Read the fd directly, not sys.stdin.read(1): the buffered
                # text wrapper can hold pending bytes (escape sequences,
                # pastes) that select() then no longer reports as readable.
                data = os.read(stdin_fd, 1024)
                if not data:
                    break
                chan.sendall(data)
    finally:
        termios.tcsetattr(stdin_fd, termios.TCSADRAIN, oldtty)
        if old_handler is not None:
            signal.signal(signal.SIGWINCH, old_handler)
        os.close(wake_r)
        os.close(wake_w)
|
|
|
|
|
|
def NewTerminal(host, port, user):
    """Open an interactive SSH session to ``host:port`` as ``user``.

    Authenticates with the private key at ``ssh_private_path % user``. It
    requests an ``xterm-256color`` shell sized like the local terminal and
    relays it through ``posix_shell``. The channel and the client are closed
    even if connecting or the session itself fails.
    """
    # Establish the SSH connection
    ssh = paramiko.SSHClient()
    try:
        ssh.load_system_host_keys()
        # NOTE(review): AutoAddPolicy accepts unknown host keys without
        # verification (no MITM protection); consider a managed known_hosts
        # file with RejectPolicy.
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())

        ssh.connect(
            host,
            port=port,
            username=user,
            key_filename=ssh_private_path % user,
            compress=False)

        # Open an interactive shell sized to the local terminal
        width, height = os.get_terminal_size()
        channel = ssh.invoke_shell("xterm-256color", width, height)
        try:
            # Relay terminal I/O until the session ends
            posix_shell(channel)
        finally:
            channel.close()
    finally:
        ssh.close()
|
|
|
|
|
|
# Open the jump database.
def connect_db():
    """Return a sqlite3 connection to the jump database at ``gSqlite3File``.

    Exits the process with status 1 if the file does not exist or cannot be
    opened.
    """
    # sqlite3.connect() never returns None. It raises on failure, and for a
    # missing path it silently creates a new, empty database, which would
    # later surface as a confusing "no such table: hosts" error.
    if not os.path.isfile(gSqlite3File):
        print("sqlite3.connect " + gSqlite3File + " failed! (no such file)", file=sys.stderr)
        sys.exit(1)
    try:
        return sqlite3.connect(gSqlite3File)
    except sqlite3.Error as err:
        print("sqlite3.connect " + gSqlite3File + " failed! (%s)" % err, file=sys.stderr)
        sys.exit(1)
|
|
|
|
def get_hosts(user):
    """Return the hosts ``user`` may log into, as ``"name:ip:port"`` strings.

    Only ``hosts`` rows with ``isdelete=0`` whose name is mapped to ``user``
    in ``hostuser`` are returned, ordered by ``hosts.id``. Exits with
    status 1 when there are none.
    """
    db = connect_db()
    try:
        # Bound parameter instead of string formatting: the caller passes
        # getpass.getuser(), which reads environment variables first, so the
        # value is user-controlled and must not be spliced into the SQL.
        hosts = db.execute(
            "select name,ip,port from hosts where isdelete=0 and name in "
            "(select hostname from hostuser where username=?) order by id",
            (user,),
        ).fetchall()
    finally:
        db.close()

    resp = ["%s:%s:%s" % (name, ip, port) for name, ip, port in hosts]
    if not resp:
        print(user + " no valid hosts", file=sys.stderr)
        sys.exit(1)
    return resp
|
|
|
|
def main():
    """Run the jump menu: pick a permitted host, open an SSH shell, repeat.

    The loop only ends when the user presses ``q`` in the picker, which exits
    the process. A failed connection is reported above the menu instead of
    terminating the program.
    """
    # Initial checks
    init()

    user = getpass.getuser()
    print("current user: " + user)

    title = "ssh hosts select:"
    menu = get_hosts(user)

    last_error = None
    while True:
        # Show the previous failure in the title: anything printed to the
        # terminal would be wiped as soon as curses redraws the menu.
        if last_error is None:
            picker_title = title
        else:
            picker_title = "%s\nlast error: %s" % (title, last_error)
        option, _ = Picker(menu, picker_title).start()
        last_error = None

        # Entries look like "name:ip:port". Split the name off the left and
        # the port off the right so an IP containing ':' (IPv6) still parses.
        _, name_sep, rest = option.partition(":")
        host, port_sep, port_text = rest.rpartition(":")
        if not (name_sep and port_sep):
            continue
        try:
            port = int(port_text)
        except ValueError:
            continue

        try:
            NewTerminal(host, port, user)
        except (paramiko.SSHException, OSError) as err:
            last_error = "%s:%s: %s" % (host, port, err)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point. main() only returns None, so the exit status is 0,
    # exactly as when falling off the end of the module.
    sys.exit(main())
|