Initial implementation
This commit is contained in:
commit
3abbc1d83b
5
pgmon.cfg
Normal file
5
pgmon.cfg
Normal file
@ -0,0 +1,5 @@
|
||||
host=localhost
|
||||
port=54315
|
||||
metric=max_frozen_age:value::SELECT max(age(datfrozenxid)) FROM pg_database
|
||||
metric=db_stats:row::SELECT * FROM pg_stat_database WHERE datname = '{datname}'
|
||||
metric=discover_dbs:column::SELECT datname FROM pg_database
|
||||
676
pgmon.py
Executable file
676
pgmon.py
Executable file
@ -0,0 +1,676 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
|
||||
if "pytest" in sys.modules:
    # The server-side modules are normally imported lazily by Server.run();
    # when this file is loaded by pytest they have to be imported here instead.
    import pytest

    @pytest.fixture(scope="session", autouse=True)
    def pytest_imports():
        """
        Import the lazily-loaded modules into the module namespace for tests
        """
        # Without this declaration the imports below would only bind names
        # local to this fixture and the rest of the module could not see them.
        global psycopg2, queue, os, signal, json

        import psycopg2
        import psycopg2.extras
        import queue
        import os
        import signal
        import json
|
||||
|
||||
#
|
||||
# Errors
|
||||
#
|
||||
|
||||
class InvalidConfigError(Exception):
    """
    Raised by Config when a configuration entry cannot be parsed
    """
|
||||
class DuplicateMetricVersionError(Exception):
    """
    Raised by Metric.add_version when the version is already registered
    """
|
||||
class UnsupportedMetricVersionError(Exception):
    """
    Raised by Metric.get_version when no registered version matches
    """
|
||||
class RequestTimeoutError(Exception):
    """
    Raised by Request.get_result when the result does not arrive in time
    """
|
||||
|
||||
#
|
||||
# Global variables
|
||||
#
|
||||
running = True
|
||||
|
||||
#
|
||||
# Signal handler
|
||||
#
|
||||
def signal_handler(sig, frame):
    """
    Signal handler that requests a shutdown

    The default SIGINT handler is always put back in place first. When the
    received signal is SIGINT, a notice is printed and the module-level
    ``running`` flag is cleared. ``frame`` is not used.
    """
    global running

    # Re-install the default SIGINT handler
    signal.signal(signal.SIGINT, signal.default_int_handler)

    # Anything other than SIGINT does not trigger a shutdown
    if sig != signal.SIGINT:
        return

    print("Shutting down ...")
    running = False
|
||||
|
||||
#
|
||||
# Classes
|
||||
#
|
||||
|
||||
class Config:
    """
    Agent configuration

    Holds the IPC/worker settings, the database connection settings and the
    metric definitions read from a ``key=value`` configuration file.
    """

    # Plain settings: config key -> (attribute name, value converter)
    _SETTINGS = {
        'socket': ('ipc_socket', str),
        'socket_timeout': ('ipc_timeout', float),
        'request_timeout': ('request_timeout', float),
        'request_queue_size': ('request_queue_size', int),
        'request_queue_timeout': ('request_queue_timeout', float),
        'worker_count': ('worker_count', int),
        'host': ('host', str),
        'port': ('port', int),
        'user': ('user', str),
        'password': ('password', str),
        'dbname': ('dbname', str),
    }

    def __init__(self, config_file, read_metrics = True):
        """
        Set the defaults, then load ``config_file``

        When ``read_metrics`` is false, ``metric`` entries are skipped.
        """
        # Set defaults
        self.ipc_socket = '/tmp/pgmon.sock' # IPC socket
        self.ipc_timeout = 10               # IPC communication timeout (s)
        self.request_timeout = 10           # Request processing timeout (s)
        self.request_queue_size = 100       # Max size of the request queue before it blocks
        self.request_queue_timeout = 2      # Max time to wait when queueing a request (s)
        self.worker_count = 4               # Number of worker threads

        self.host = 'localhost'             # DB host
        self.port = 5432                    # DB port
        self.user = 'postgres'              # DB user
        self.password = None                # DB password
        self.dbname = 'postgres'            # DB name

        self.metrics = {}                   # Metrics

        # Dynamic values
        self.pg_version = None              # PostgreSQL version

        # Read config
        self.read_file(config_file, read_metrics)

    def read_file(self, config_file, read_metrics = True):
        """
        Read settings from ``config_file``

        Blank lines and lines starting with ``#`` are skipped. ``include``
        entries are read recursively with the same ``read_metrics`` flag.
        Keys that are not recognised are ignored.

        Raises InvalidConfigError for a line without ``=`` or a value that
        cannot be converted to the type of its setting.
        """
        with open(config_file, 'r') as f:
            for raw_line in f:
                line = raw_line.strip()

                if line == '' or line.startswith('#'):
                    continue

                # partition() never fails, so a missing '=' can be reported
                # as a configuration error instead of an unpacking error
                (key, sep, value) = line.partition('=')
                if sep == '':
                    raise InvalidConfigError("{}: {}".format(config_file, line))
                key = key.strip()
                value = value.strip()

                if key == 'include':
                    self.read_file(value, read_metrics)
                elif key == 'metric':
                    if read_metrics:
                        self.add_metric(value)
                elif key in self._SETTINGS:
                    (attr, convert) = self._SETTINGS[key]
                    try:
                        setattr(self, attr, convert(value))
                    except ValueError as e:
                        raise InvalidConfigError("{}: {}".format(config_file, line)) from e

    def add_metric(self, metric_def):
        """
        Register one metric version from a ``name:type:version:sql`` definition

        An empty version means 0. SQL of the form ``file:PATH`` is replaced
        by the contents of PATH.

        Raises InvalidConfigError when the definition has fewer than four
        fields or the version is not an integer.
        """
        parts = metric_def.split(':', 3)
        if len(parts) != 4:
            raise InvalidConfigError("Invalid metric definition: {}".format(metric_def))
        (name, ret_type, version, sql) = parts

        if sql.startswith('file:'):
            path = sql[len('file:'):]
            with open(path, 'r') as f:
                sql = f.read()

        try:
            version = 0 if version == '' else int(version)
        except ValueError as e:
            raise InvalidConfigError("Invalid metric version: {}".format(metric_def)) from e

        metric = self.metrics.get(name)
        if metric is None:
            metric = Metric(name, ret_type)
            self.metrics[name] = metric

        metric.add_version(version, sql)

    def get_pg_version(self):
        """
        Return the server version number, querying the database on first use
        """
        if self.pg_version is None:
            db = DB(self)
            try:
                rows = db.query('SHOW server_version_num')
                self.pg_version = int(rows[0]['server_version_num'])
            finally:
                # Close the connection even when the query fails
                db.close()
        return self.pg_version
|
||||
|
||||
|
||||
class DB:
    """
    Database access

    Wraps one psycopg2 connection opened with the settings from a Config
    and switched to a read-only, autocommit session.
    """
    def __init__(self, config, dbname=None):
        """
        Connect to ``dbname``, or to ``config.dbname`` when it is None
        """
        # Set first so close() works even when connecting fails
        self.conn = None

        if dbname is None:
            dbname = config.dbname
        self.conn = psycopg2.connect(
            host = config.host,
            port = config.port,
            user = config.user,
            password = config.password,
            dbname = dbname)
        self.conn.set_session(readonly=True, autocommit=True)

    def query(self, sql, args=None):
        """
        Execute ``sql`` with ``args`` and return all rows

        The cursor is a RealDictCursor. ``args`` defaults to no parameters.
        """
        # A fresh list per call replaces the former shared mutable default
        params = [] if args is None else args
        with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
            cur.execute(sql, params)
            return cur.fetchall()

    def close(self):
        """
        Close the connection if one is open; psycopg2 errors are only printed
        """
        try:
            if self.conn is not None:
                self.conn.close()
        except psycopg2.Error as e:
            print("Caught an error when closing a connection: {}".format(e))
        self.conn = None
|
||||
|
||||
|
||||
class Request:
    """
    A metric request

    ``complete`` is a lock that is held from creation until set_result()
    releases it, which is how get_result() knows the result is available.
    """
    def __init__(self, key, args = None):
        """
        Create a pending request for metric ``key`` with optional ``args``
        """
        self.key = key
        # A new dict per request; a shared default would leak arguments
        # from one request into the next
        self.args = {} if args is None else args
        self.result = None

        # Add a lock to indicate when the request is complete
        self.complete = threading.Lock()
        self.complete.acquire()

    def set_result(self, result):
        """
        Store ``result`` and mark the request as complete
        """
        self.result = result

        # Release the lock
        self.complete.release()

    def get_result(self, timeout = -1):
        """
        Wait for the request to complete and return its result

        ``timeout`` is passed to the lock's acquire(); the default of -1
        waits indefinitely. Raises RequestTimeoutError when the wait
        times out.
        """
        # Wait until the request has been completed
        if not self.complete.acquire(timeout = timeout):
            raise RequestTimeoutError

        # Give the lock back so the result can be fetched more than once
        self.complete.release()
        return self.result
|
||||
|
||||
|
||||
class Metric:
    """
    A metric

    Keeps the SQL variants of one metric keyed by the lowest version number
    they apply to, plus a one-entry cache of the last lookup.
    """
    def __init__(self, name, ret_type):
        """
        Create a metric named ``name`` with return type ``ret_type``
        """
        self.name = name
        self.versions = {}
        self.type = ret_type
        self.cached = None
        self.cached_version = None

    def add_version(self, version, sql):
        """
        Register ``sql`` for ``version``

        Raises DuplicateMetricVersionError if ``version`` is already present.
        """
        if version in self.versions:
            raise DuplicateMetricVersionError
        self.versions[version] = MetricVersion(sql)

        # The new entry may be a better match than the cached one
        self.cached = None
        self.cached_version = None

    def get_version(self, version):
        """
        Return the entry with the highest key that is <= ``version``

        Raises UnsupportedMetricVersionError when no key qualifies.
        """
        # Since we're usually going to keep asking for the same version,
        # we cache the metric version and only search through the versions
        # again if the PostgreSQL version changes
        if self.cached is None or self.cached_version != version:
            self.cached = None
            self.cached_version = None
            for v in sorted(self.versions, reverse=True):
                if version >= v:
                    self.cached = self.versions[v]
                    self.cached_version = version
                    break

        if self.cached is None:
            raise UnsupportedMetricVersionError

        return self.cached
|
||||
|
||||
|
||||
class MetricVersion:
    """
    One SQL variant of a metric
    """
    def __init__(self, sql):
        # SQL template; may contain {name} placeholders
        self.sql = sql

    def get_sql(self, args):
        """
        Fill the template's placeholders from the ``args`` mapping
        """
        # NOTE(review): values are substituted into the SQL text verbatim
        template = self.sql
        rendered = template.format(**args)
        return rendered
|
||||
|
||||
|
||||
class IPC:
    """
    IPC handler

    Messages are UTF-8 text prefixed with their byte length as a 4-byte
    big-endian integer, exchanged over a UNIX stream socket.
    """
    def __init__(self, config, mode):
        """
        Set up the socket for ``mode`` ('server' or 'agent')
        """
        self.config = config
        self.mode = mode
        self.reconnect()

    def reconnect(self):
        """
        Create the listening socket (server) or an unconnected socket (agent)

        Raises RuntimeError for any other mode.
        """
        # Imported here so this method does not depend on another function
        # having imported os first
        import os

        if self.mode == 'server':
            # Try to clean up the socket if it exists
            try:
                os.unlink(self.config.ipc_socket)
            except OSError:
                if os.path.exists(self.config.ipc_socket):
                    raise

            # Create the socket
            self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            # Set a timeout for accepting connections to be able to catch signals
            self.socket.settimeout(1)
            self.socket.bind(self.config.ipc_socket)
            self.socket.listen(1)

        elif self.mode == 'agent':
            # Create the socket; connect() establishes the connection
            self.conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            self.conn.settimeout(self.config.ipc_timeout)

        else:
            raise RuntimeError("Unknown IPC mode: {}".format(self.mode))

    def connect(self):
        """
        Establish a connection to a socket
        """
        self.conn.connect(self.config.ipc_socket)
        return self.conn

    def accept(self):
        """
        Accept a connection
        """
        (conn, _) = self.socket.accept()
        conn.settimeout(self.config.ipc_timeout)
        return conn

    def send(self, conn, msg):
        """
        Send ``msg`` on ``conn`` as one length-prefixed message
        """
        msg_bytes = msg.encode("utf-8")
        msg_len_bytes = len(msg_bytes).to_bytes(4, byteorder='big')
        conn.sendall(msg_len_bytes + msg_bytes)

    def recv(self, conn, timeout=-1):
        """
        Receive one length-prefixed message from ``conn`` and return its text

        ``timeout`` is accepted for compatibility and is not used; the
        timeout configured on ``conn`` applies. Raises ConnectionError if
        the peer closes the connection before a full message has arrived.
        """
        header = self._recv_exact(conn, 4)
        msg_len = int.from_bytes(header, byteorder='big')
        return self._recv_exact(conn, msg_len).decode("utf-8")

    @staticmethod
    def _recv_exact(conn, count):
        """
        Read exactly ``count`` bytes from ``conn``
        """
        data = bytearray()
        while len(data) < count:
            # Never ask for more than is missing, so a following message
            # stays in the socket buffer
            chunk = conn.recv(min(count - len(data), 1024))
            if not chunk:
                # An empty read means the peer closed the connection
                raise ConnectionError("IPC connection closed before a full message was received")
            data += chunk
        return bytes(data)
|
||||
|
||||
|
||||
class Agent:
    """
    The agent side of the connector
    """
    @staticmethod
    def run(config, key, args):
        """
        Send one request to the server and print the response

        The request is ``key`` followed by the comma-joined ``args``. On any
        communication error "IPC error" is printed and the process exits
        with status 1.
        """
        # Connect to the socket
        ipc = IPC(config, 'agent')
        try:
            conn = ipc.connect()

            # Send a request
            ipc.send(conn, "{},{}".format(key, ",".join(args)))

            # Wait for a response
            res = ipc.recv(conn)

        except Exception:
            # Exception rather than a bare except, so that Ctrl-C and
            # interpreter exit are not reported as IPC errors
            print("IPC error")
            sys.exit(1)

        finally:
            # Release the socket on both the success and the error path
            ipc.conn.close()

        # Output the response
        print(res)
|
||||
|
||||
|
||||
class Server:
    """
    The server side of the connector
    """
    @staticmethod
    def run(config):
        """
        Run the server loop until the ``running`` flag is cleared

        Starts the worker threads, accepts IPC connections, queues one
        Request per connection and hands each to a Responder. After the
        loop ends, the workers are joined and the socket is removed.
        """
        # These modules are only needed in server mode. They are bound as
        # module globals because the other classes and the signal handler
        # refer to them by their global names.
        global psycopg2, queue, os, signal, json

        import psycopg2
        import psycopg2.extras
        import queue
        import os
        import signal
        import json

        # Set up the signal handler
        signal.signal(signal.SIGINT, signal_handler)

        # Create request queue
        req_queue = queue.Queue(config.request_queue_size)

        # Spawn worker threads
        workers = [Worker(config, req_queue) for _ in range(config.worker_count)]
        for worker in workers:
            worker.start()

        # Listen on ipc socket
        ipc = IPC(config, 'server')

        while True:
            # Wait for a request connection
            try:
                conn = ipc.accept()
            except socket.timeout:
                conn = None

            # See if we should exit
            if not running:
                if conn is not None:
                    conn.close()
                break

            # If we just timed out waiting for a request, go back to waiting
            if conn is None:
                continue

            # Receive the request string (csv)
            try:
                req_csv = ipc.recv(conn)
            except socket.timeout:
                print("IPC communication timeout receiving request")
                conn.close()
                continue
            except (OSError, UnicodeDecodeError) as e:
                print("IPC error receiving request: {}".format(e))
                conn.close()
                continue

            # Parse it into a key and a dict of key=value arguments
            try:
                (key, args_csv) = req_csv.split(',', 1)
                args_dict = {}
                if args_csv != "":
                    args_dict = dict(a.split('=', 1) for a in args_csv.split(','))
            except ValueError:
                print("Received invalid request: '{}'".format(req_csv))
                conn.close()
                continue

            # Create request object
            req = Request(key, args_dict)

            # Queue the request
            try:
                req_queue.put(req, timeout=config.request_queue_timeout)
            except queue.Full:
                print("Failed to queue request")
                # No worker will see this request, so answer the client here
                try:
                    ipc.send(conn, "ERROR: Queue timeout")
                except OSError as e:
                    print("Failed to send response: {}".format(e))
                finally:
                    conn.close()
                continue

            # Spawn a thread to wait for the result
            r = Responder(config, ipc, conn, req)
            r.start()

        # Join worker threads
        for worker in workers:
            worker.join()

        # Best-effort removal of the listening socket
        ipc.socket.close()
        try:
            os.unlink(config.ipc_socket)
        except OSError as e:
            print("Failed to remove IPC socket: {}".format(e))

        print("Good bye")
|
||||
|
||||
|
||||
class Worker(threading.Thread):
    """
    Worker thread that processes requests (ie: queries the database)
    """
    def __init__(self, config, queue):
        """
        Create a worker that takes requests from ``queue``
        """
        super(Worker, self).__init__()

        self.config = config
        self.queue = queue

    def run(self):
        """
        Process requests until the ``running`` flag is cleared

        Every request taken from the queue gets a result, including on
        errors and on shutdown, so nothing waiting on it is left hanging.
        """
        # Imported here so the thread does not depend on another function
        # having imported the module first
        import queue

        while True:
            # Wait for a request
            try:
                req = self.queue.get(timeout=1)
            except queue.Empty:
                req = None

            # Check if we're supposed to exit
            if not running:
                if req is not None:
                    req.set_result("ERROR: Shutting down")
                break

            # If we got here because we waited too long for a request, go back to waiting
            if req is None:
                continue

            try:
                result = self._process(req)
            except Exception as e:
                # Keep the worker alive and still answer the request
                print("Error: {}".format(e))
                result = "ERROR: Failed to process request"
            req.set_result(result)

    def _process(self, req):
        """
        Run the query for ``req`` and return the result as a string
        """
        import json

        # Find the requested metric
        metric = self.config.metrics.get(req.key)
        if metric is None:
            return "ERROR: Unknown key '{}'".format(req.key)

        # Get the DB version
        try:
            pg_version = self.config.get_pg_version()
        except Exception as e:
            print("Error: {}".format(e))
            return "Failed to retrieve database version"

        # Get the query to use
        try:
            mv = metric.get_version(pg_version)
        except UnsupportedMetricVersionError:
            return "Unsupported PostgreSQL version for metric"

        # Query the database named by the 'db' argument, if there is one
        dbname = req.args.get('db')
        db = None
        try:
            db = DB(self.config, dbname)
            res = db.query(mv.get_sql(req.args))
        except Exception as e:
            print("Error: {}".format(e))
            return "Failed to query database"
        finally:
            # Make sure the database connection is closed (ie: if the query timed out)
            if db is not None:
                try:
                    db.close()
                except Exception as e:
                    print("Failed to close database connection")
                    print("Error: {}".format(e))

        # Format the rows according to the metric type
        if not res:
            return "Empty result set"
        if metric.type == 'value':
            return "{}".format(list(res[0].values())[0])

        # default=str is deliberate: presumably some column values (e.g.
        # timestamps or decimals) are not JSON-serializable as they are
        if metric.type == 'row':
            return json.dumps(res[0], default=str)
        if metric.type == 'column':
            return json.dumps([list(r.values())[0] for r in res], default=str)
        return json.dumps(res, default=str)
|
||||
|
||||
class Responder(threading.Thread):
    """
    Thread responsible for replying to requests
    """
    def __init__(self, config, ipc, conn, req):
        """
        Remember what to answer and where; the work itself happens in run()
        """
        super(Responder, self).__init__()

        self.config = config
        self.ipc = ipc
        self.conn = conn
        self.req = req

    def run(self):
        """
        Wait for the request's result, send it to the client and close the connection

        The wait is limited to ``config.request_timeout`` seconds; on
        timeout an error string is sent instead.
        """
        # Wait for a result
        try:
            result = self.req.get_result(timeout=self.config.request_timeout)
        except RequestTimeoutError:
            result = "ERROR: Request timeout"

        # Send the result back to the client
        try:
            self.ipc.send(self.conn, result)
        except OSError as e:
            print("Failed to send response: {}".format(e))
        finally:
            self.conn.close()
|
||||
|
||||
def main(argv=None):
    """
    Parse the command line and run in agent or server mode

    ``argv`` is the list of arguments to parse; None means the process
    command line. Agent mode requires --key.
    """
    parser = argparse.ArgumentParser(
        prog='pgmon',
        description='A bridge between monitoring tools and PostgreSQL')

    # General options
    parser.add_argument('-c', '--config', default='pgmon.cfg')
    parser.add_argument('-v', '--verbose', action='store_true')

    # Operational mode
    parser.add_argument('-m', '--mode', choices=['agent', 'server'], required=True)

    # Agent options
    parser.add_argument('-k', '--key')
    parser.add_argument('-a', '--args', nargs='*', default=[])

    args = parser.parse_args(argv)

    if args.mode == 'agent':
        # Without a key the request would carry the literal text "None"
        if args.key is None:
            parser.error("--key is required in agent mode")

        config = Config(args.config, read_metrics = False)
        Agent.run(config, args.key, args.args)

    else:
        config = Config(args.config)
        Server.run(config)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
##
|
||||
# Tests
|
||||
##
|
||||
class TestConfig:
    """
    Placeholder for Config tests (none yet)
    """
|
||||
|
||||
class TestDB:
    """
    Placeholder for DB tests (none yet)
    """
|
||||
|
||||
class TestRequest:
    """
    Tests for Request
    """
    def test_request_creation(self):
        # A request without arguments starts out locked
        bare = Request('foo', {})
        assert bare.key == 'foo'
        assert len(bare.args) == 0
        assert bare.complete.locked()

        # A request with arguments keeps them and also starts out locked
        with_args = Request('foo', {'arg1': 'value1', 'arg2': 'value2'})
        assert with_args.key == 'foo'
        assert len(with_args.args) == 2
        assert with_args.complete.locked()

    def test_request_lock(self):
        # Setting a result unlocks the request and makes the result readable
        pending = Request('foo', {})
        assert pending.complete.locked()
        pending.set_result('blah')
        assert not pending.complete.locked()
        assert 'blah' == pending.get_result()
|
||||
|
||||
|
||||
|
||||
class TestMetric:
    """
    Tests for Metric
    """
    def test_metric_creation(self):
        # Test basic creation
        m1 = Metric('foo', 'value')
        assert m1.name == 'foo'
        assert len(m1.versions) == 0
        assert m1.type == 'value'
        assert m1.cached is None
        assert m1.cached_version is None

    def test_metric_add_version(self):
        # Test adding versions
        m1 = Metric('foo', 'value')
        assert len(m1.versions) == 0
        m1.add_version(0, 'default')
        assert len(m1.versions) == 1
        m1.add_version(120003, 'v12.3')
        assert len(m1.versions) == 2

        # Make sure added versions are correct
        assert m1.versions[0].sql == 'default'
        assert m1.versions[120003].sql == 'v12.3'

    def test_metric_get_version(self):
        # Test retrieving metric versions
        m1 = Metric('foo', 'value')
        m1.add_version(100000, 'v10.0')
        m1.add_version(120000, 'v12.0')

        # Make sure cache is initially empty
        assert m1.cached is None
        assert m1.cached_version is None

        assert m1.get_version(110003).sql == 'v10.0'

        # Make sure cache is set
        assert m1.cached is not None
        # Compare by value; identity of int objects is not guaranteed
        assert m1.cached_version == 110003

        # Make sure returned value changes with version
        assert m1.get_version(120000).sql == 'v12.0'
        assert m1.get_version(150005).sql == 'v12.0'

        # Make sure an error is thrown when no version matches
        with pytest.raises(UnsupportedMetricVersionError):
            m1.get_version(90603)

        # Add a default version
        m1.add_version(0, 'default')
        assert m1.get_version(90603).sql == 'default'
        assert m1.get_version(110003).sql == 'v10.0'
        assert m1.get_version(120000).sql == 'v12.0'
        assert m1.get_version(150005).sql == 'v12.0'
|
||||
|
||||
class TestMetricVersion:
    """
    Tests for MetricVersion
    """
    def test_metric_version_creation(self):
        # The SQL text is stored unchanged
        plain = MetricVersion('test')
        assert plain.sql == 'test'

    def test_metric_version_templating(self):
        # A template without placeholders comes back as-is
        static = MetricVersion('foo')
        assert static.get_sql({}) == 'foo'

        # Placeholders are filled by name, regardless of argument order
        templated = MetricVersion('foo {a1} {a3} {a2}')
        values = {'a1': 'bar', 'a2': 'blah blah blah', 'a3': 'baz'}
        assert templated.get_sql(values) == 'foo bar baz blah blah blah'
|
||||
|
||||
|
||||
class TestIPC:
    """
    Placeholder for IPC tests (none yet)
    """
|
||||
|
||||
class TestAgent:
    """
    Placeholder for Agent tests (none yet)
    """
|
||||
|
||||
class TestServer:
    """
    Placeholder for Server tests (none yet)
    """
|
||||
|
||||
class TestWorker:
    """
    Placeholder for Worker tests (none yet)
    """
|
||||
|
||||
class TestResponder:
    """
    Placeholder for Responder tests (none yet)
    """
|
||||
Loading…
Reference in New Issue
Block a user