#!/usr/bin/env python3
import argparse
import socket
import sys
import threading
import psycopg2
import psycopg2.extras
import queue
import os
import signal
import json
import logging
#
# Errors
#
class InvalidConfigError(Exception):
pass
class UnknownClusterError(Exception):
pass
class DuplicateMetricVersionError(Exception):
pass
class UnsupportedMetricVersionError(Exception):
pass
class RequestTimeoutError(Exception):
pass
class DBError(Exception):
    pass
class SQLInjectionError(Exception):
    pass
#
# Logging
#
logger = None
# Store the current log file since there is no public method to get the filename
# from a FileHandler object
current_log_file = None
# Handler objects for easy adding/removing/modifying
file_log_handler = None
stderr_log_handler = None
def init_logging(config):
"""
Initialize (or re-initialize/modify) logging
"""
global logger
global current_log_file
global file_log_handler
global stderr_log_handler
# Get the logger object
logger = logging.getLogger(__name__)
# Create a common formatter
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
# Set up or modify stderr logging
if config.stderr_log_level != 'OFF':
# Create and add the handler if it doesn't exist
if stderr_log_handler is None:
stderr_log_handler = logging.StreamHandler()
logger.addHandler(stderr_log_handler)
# Set the formatter
stderr_log_handler.setFormatter(formatter)
# Set the log level
level = logging.getLevelName(config.stderr_log_level)
stderr_log_handler.setLevel(level)
else:
if stderr_log_handler is not None:
logger.removeHandler(stderr_log_handler)
stderr_log_handler = None
# Set up or modify file logging
if config.file_log_level != 'OFF':
    # Check if we're switching files
if file_log_handler is not None and config.log_file != current_log_file:
old_file_logger = file_log_handler
file_log_handler = None
else:
old_file_logger = None
if config.log_file is not None:
# Create and add the handler if it doesn't exist
if file_log_handler is None:
file_log_handler = logging.FileHandler(config.log_file, encoding='utf-8')
logger.addHandler(file_log_handler)
current_log_file = config.log_file
# Set the formatter
file_log_handler.setFormatter(formatter)
# Set the log level
level = logging.getLevelName(config.file_log_level)
file_log_handler.setLevel(level)
# Remove the old handler if there was one
if old_file_logger is not None:
logger.removeHandler(old_file_logger)
# Note where logs are being written
#print("Logging to {} ({})".format(config.log_file, config.file_log_level))
else:
if file_log_handler is not None:
logger.removeHandler(file_log_handler)
file_log_handler = None
# Set the log level for the logger itself
levels = []
if stderr_log_handler is not None and stderr_log_handler.level != logging.NOTSET:
levels.append(stderr_log_handler.level)
if file_log_handler is not None and file_log_handler.level != logging.NOTSET:
levels.append(file_log_handler.level)
if len(levels) > 0:
logger.setLevel(min(levels))
else:
# If we have no handlers, just bump the level to the max
logger.setLevel(logging.CRITICAL)
#
# PID file handling
#
def write_pid_file(pid_file):
if pid_file is not None:
with open(pid_file, 'w') as f:
f.write("{}".format(os.getpid()))
def read_pid_file(pid_file):
if pid_file is None:
raise RuntimeError("No PID file specified")
with open(pid_file, 'r') as f:
return int(f.read().strip())
def remove_pid_file(pid_file):
if pid_file is not None:
os.unlink(pid_file)
#
# Global flags for signal handler
#
running = True
reload = False
#
# Signal handler
#
def signal_handler(sig, frame):
"""
Function for handling signals
    INT => Shut down
TERM => Shut down
QUIT => Shut down
HUP => Reload
"""
# Restore the original handler
signal.signal(signal.SIGINT, signal.default_int_handler)
# Signal everything to shut down
if sig in [ signal.SIGINT, signal.SIGTERM, signal.SIGQUIT ]:
logger.info("Shutting down ...")
global running
running = False
# Signal a reload
elif sig == signal.SIGHUP:
logger.info("Reloading config ...")
global reload
reload = True
#
# Classes
#
class Config:
"""
Agent configuration
Note: The config is initially loaded before logging is configured, so be
mindful about logging anything in this class.
Params:
args: (argparse.Namespace) Command line arguments
read_metrics: (bool) Indicate if metrics should be parsed
read_clusters: (bool) Indicate if cluster information should be parsed
Exceptions:
InvalidConfigError: Indicates an issue with the config file
OSError: Thrown if there's an issue opening a config file
ValueError: Thrown if there is an encoding error
"""
def __init__(self, args, read_metrics = True, read_clusters = True):
# Set defaults
self.pid_file = None # PID file
self.ipc_socket = 'pgmon.sock' # IPC socket
self.ipc_timeout = 10 # IPC communication timeout (s)
self.request_timeout = 10 # Request processing timeout (s)
self.request_queue_size = 100 # Max size of the request queue before it blocks
self.request_queue_timeout = 2 # Max time to wait when queueing a request (s)
self.worker_count = 4 # Number of worker threads
self.stderr_log_level = 'INFO' # Log level for stderr logging (or 'off')
        self.file_log_level = 'INFO' # Log level for file logging (or 'off')
self.log_file = None # Log file
self.metrics = {} # Metrics
self.clusters = {} # Known clusters
# Check if we have a config file
self.config_file = args.config
# Read config
if self.config_file is not None:
self.read_file(self.config_file, read_metrics, read_clusters)
# Override anything that was specified on the command line
if args.pidfile is not None:
self.pid_file = args.pidfile
if args.logfile is not None:
self.log_file = args.logfile
if args.socket is not None:
self.ipc_socket = args.socket
def read_file(self, config_file, read_metrics, read_clusters):
"""
Read a config file, possibly skipping metrics and clusters to lighten
the load on the agent.
Params:
config_file: (str) Path to the config file
read_metrics: (bool) Indicate if metrics should be parsed
read_clusters: (bool) Indicate if cluster information should be parsed
Exceptions:
InvalidConfigError: Indicates an issue with the config file
"""
try:
with open(config_file, 'r') as f:
for line in f:
# Clean up any whitespace at either end
line = line.strip()
# Skip empty lines and comments
if line.startswith('#'):
continue
elif line == '':
continue
                    # Separate the line into a key-value pair and clean up extra
                    # white space that may have been around the '='
                    try:
                        (key, value) = [x.strip() for x in line.split('=', 1)]
                    except ValueError:
                        raise InvalidConfigError("Invalid config line in {}: {}".format(config_file, line))
# Handle each key appropriately
if key == 'include':
print("Including file: {}".format(value))
self.read_file(value, read_metrics, read_clusters)
elif key == 'pid_file':
self.pid_file = value
elif key == 'ipc_socket':
self.ipc_socket = value
elif key == 'ipc_timeout':
self.ipc_timeout = float(value)
elif key == 'request_timeout':
self.request_timeout = float(value)
elif key == 'request_queue_size':
self.request_queue_size = int(value)
elif key == 'request_queue_timeout':
self.request_queue_timeout = float(value)
elif key == 'worker_count':
self.worker_count = int(value)
elif key == 'stderr_log_level':
self.stderr_log_level = value.upper()
elif key == 'file_log_level':
self.file_log_level = value.upper()
elif key == 'log_file':
self.log_file = value
elif key == 'cluster':
if read_clusters:
self.add_cluster(value)
elif key == 'metric':
if read_metrics:
print("Adding metric: {}".format(value))
self.add_metric(value)
else:
raise InvalidConfigError("WARNING: Unknown config option: {}".format(key))
except OSError as e:
raise InvalidConfigError("Failed to open/read config file: {}".format(e))
except ValueError as e:
raise InvalidConfigError("Encoding error in config file: {}".format(e))
def add_cluster(self, cluster_def):
"""
Parse and add connection information about a cluster to the config.
Each cluster line is of the format:
            <name>:[address]:[port]:[dbname]:[user]:[password]
        The name is a unique, arbitrary identifier to associate this cluster
        with a request from the monitoring agent.
        The address can be an IP address, host name, or path to the directory
        where a PostgreSQL socket exists.
        The dbname field is the default database to connect to for metrics
        which don't specify a database. This is also used to identify the
        PostgreSQL version.
        Default values replace empty components, except for the name field:
address: /var/run/postgresql
port: 5432
dbname: postgres
user: postgres
password: None
Params:
cluster_def: (str) Cluster definition string
Exceptions:
InvalidConfigError: Thrown if the cluster entry is missing any fields
or contains invalid content
"""
# Split up the fields
try:
(name, address, port, dbname, user, password) = cluster_def.split(':')
except ValueError:
raise InvalidConfigError("Missing fields in cluster definition: {}".format(cluster_def))
# Make sure we have a name
if name == '':
raise InvalidConfigError("Cluster must have a name: {}".format(cluster_def))
# Set defaults for anything that's blank
if address == '':
address = '/var/run/postgresql'
if port == '':
port = 5432
else:
# Convert the port to a number here
try:
port = int(port)
except ValueError:
raise InvalidConfigError("Invalid port number: {}".format(port))
if dbname == '':
dbname = 'postgres'
if user == '':
user = 'postgres'
if password == '':
password = None
# Create and add the cluster object
self.clusters[name] = Cluster(name, address, port, dbname, user, password)
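    # Illustrative example (hypothetical cluster): the line
    #   prod:db1.example.com::myapp::
    # connects to db1.example.com:5432, database myapp, as user postgres with
    # no password, since the empty fields fall back to the defaults above.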
def add_metric(self, metric_def):
"""
Parse and add a metric or metric version to the config.
Each metric definition is of the form:
<name>:[return type]:[PostgreSQL version]:<sql>
The name is an identifier used to reference this metric from the
monitoring server.
        The return type indicates how to format the results. Possible values
        are:
            value: A single value is returned (first column of first record)
            row: The first record is returned as a dict
            column: A list is returned containing the first column of all records
            set: All records are returned as a list of dicts
The version field is the first version of PostgreSQL for which this
query is valid.
        The sql field is either the SQL to execute, or a string of the form:
file:<path>
where <path> is the path to a file containing the SQL. In either case,
the SQL can contain references to parameters that are to be passed to
        PostgreSQL. These are substituted using Python's str.format method, so
        variable names should be enclosed in curly braces like {foo}
Params:
metric_def: (str) Metric definition string
Exceptions:
InvalidConfigError: Thrown if the metric entry is missing any fields
or contains invalid content
"""
# Split up the fields
try:
(name, ret_type, version, sql) = metric_def.split(':', 3)
except ValueError:
raise InvalidConfigError("Missing fields in metric definition: {}".format(metric_def))
# Make sure we have a name and some SQL
if name == '':
raise InvalidConfigError("Missing name for metric: {}".format(metric_def))
        # An empty SQL query indicates a metric is not supported after a particular version
        if sql == '':
            sql = None
        # If the sql is in a separate file, read it in
        try:
            if sql is not None and sql.startswith('file:'):
                (_, path) = sql.split(':', 1)
with open(path, 'r') as f:
sql = f.read()
except OSError as e:
raise InvalidConfigError("Failed to open/read SQL file: {}".format(e))
except ValueError as e:
raise InvalidConfigError("Encoding error in SQL file: {}".format(e))
# If no starting version is given, set it to 0
if version == '':
version = 0
# Find or create the metric
try:
metric = self.metrics[name]
except KeyError:
metric = Metric(name, ret_type)
self.metrics[name] = metric
# Add what was given as a version of the metric
metric.add_version(int(version), sql)
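    # Illustrative example of a versioned metric (hypothetical name wal_lsn):
    # the function was renamed in PostgreSQL 10, so two entries are defined
    # and get_version() picks the newest one at or below the server version:
    #   metric = wal_lsn:value:90600:SELECT pg_current_xlog_location()
    #   metric = wal_lsn:value:100000:SELECT pg_current_wal_lsn()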
def get_pg_version(self, cluster_name):
"""
Return the version of PostgreSQL running on the specified cluster. The
version is cached after the first successful retrieval.
Params:
cluster_name: (str) The identifier for the cluster, as defined in the
config file
Returns:
The version using PostgreSQL's integer format (Mmmpp or MM00mm)
Exceptions:
UnknownClusterError: Thrown if the named cluster is not defined
DBError: A database error occurred
"""
# TODO: Move this out of the Config class. Possibly move the code
# to the DB class, while storing the value in the Cluster object.
# Find the cluster
try:
cluster = self.clusters[cluster_name]
except KeyError:
raise UnknownClusterError(cluster_name)
# Query the cluster if we don't already know the version
# TODO: Expire this value at some point to pick up upgrades.
if cluster.pg_version is None:
db = DB(self, cluster_name)
cluster.pg_version = int(db.query('SHOW server_version_num')[0]['server_version_num'])
db.close()
# Return the version number
return cluster.pg_version
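    # Examples of the integer format: 90603 is 9.6.3 (Mmmpp) and 120003 is
    # 12.3 (MM00mm), so plain numeric comparison orders versions correctly.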
class Cluster:
"""
Connection information for a PostgreSQL cluster
Params:
name: A unique, arbitrary identifier to associate this cluster with a
request from the monitoring agent.
        address: An IP address, host name, or path to the directory where a
PostgreSQL socket exists.
port: The database port number.
dbname: Default database to connect to for metrics which don't specify a
database. This is also used to identify the PostgreSQL version.
user: The database user to connect as.
password: The password to use when connecting to the database. Leave
this blank to use either no password or a password stored in
the executing user's ~/.pgpass file.
"""
def __init__(self, name, address, port, dbname, user, password):
self.name = name
self.address = address
self.dbname = dbname
self.port = port
self.user = user
self.password = password
# Dynamically acquired PostgreSQL version
self.pg_version = None
class DB:
"""
Database access
Params:
config: The agent config
cluster_name: The name of the cluster to connect to
dbname: A database to connect to, or None to use the default
Exceptions:
UnknownClusterError: Thrown if the named cluster is not defined
DBError: Thrown for any database related error
"""
def __init__(self, config, cluster_name, dbname=None):
logger.debug("Creating connection to cluster: {}".format(cluster_name))
# Find the named cluster
try:
cluster = config.clusters[cluster_name]
except KeyError:
raise UnknownClusterError(cluster_name)
# Use the default database if not given one
if dbname is None:
logger.debug("Using default database: {}".format(cluster.dbname))
dbname = cluster.dbname
# Connect to the database
try:
            self.conn = psycopg2.connect(
                host=cluster.address,
                port=cluster.port,
                user=cluster.user,
                password=cluster.password,
                dbname=dbname)
except Exception as e:
raise DBError("Failed to connect to the database: {}".format(e))
# Make the connection readonly and enable autocommit
try:
self.conn.set_session(readonly=True, autocommit=True)
except Exception as e:
raise DBError("Failed to set session parameters: {}".format(e))
def query(self, sql, args=[]):
"""
Execute a query in the database.
Params:
sql: (str) The SQL statement to execute
args: (list) List of positional arguments for the query
Exceptions:
DBError: Thrown for any database related error
"""
logger.debug("Executuing query: {}".format(sql))
try:
with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
cur.execute(sql, args)
                return cur.fetchall()
        except Exception as e:
            raise DBError("Failed to execute query: {}".format(e))
def close(self):
logger.debug("Closing db connection")
try:
if self.conn is not None:
self.conn.close()
except psycopg2.Error as e:
logger.warning("Caught an error when closing a connection: {}".format(e))
self.conn = None
class Request:
"""
A metric request
Params:
cluster_name: (str) The name of the cluster the request is for
metric_name: (str) The name of the metric to obtain
args: (dict) Dictionary of arguments for the metric
"""
def __init__(self, cluster_name, metric_name, args = {}):
self.cluster_name = cluster_name
self.metric_name = metric_name
self.args = args
self.result = None
# Add a lock to indicate when the request is complete
self.complete = threading.Lock()
self.complete.acquire()
def set_result(self, result):
"""
Set the result for the metric, and release the lock that allows the
result to be returned to the client.
"""
# Set the result value
self.result = result
# Release the lock
self.complete.release()
def get_result(self, timeout = -1):
"""
Retrieve the result for the metric. This will wait for a result to be
available. If timeout is >0, a RequestTimeoutError exception will be
thrown if the specified number of seconds elapses.
Params:
timeout: (float) Number of seconds to wait before timing out
"""
# Wait until the request has been completed
if self.complete.acquire(timeout = timeout):
return self.result
else:
raise RequestTimeoutError()
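    # Illustration of the intended handshake: the server thread creates the
    # Request while a Responder blocks in get_result(); a Worker later calls
    # set_result(), releasing the lock and unblocking the Responder.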
class Metric:
"""
A metric
The return_type parameter controls how the results will be formatted when
    returned to the client. Possible values are:
        value: Return a single value
        row: Return a dictionary representing the first queried record
        column: Return a list comprised of the first column of the query results
        set: Return a list of dictionaries representing the queried records
Params:
name: (str) The name to associate with the metric
ret_type: (str) The return type of the metric (value, column, or set)
"""
def __init__(self, name, ret_type):
self.name = name
self.versions = {}
self.type = ret_type
        # Cache of the resolved query for each requested PostgreSQL version
        self.cached = {}
def add_version(self, pg_version, sql):
"""
Add a versioned query for the metric
Params:
pg_version: (int) The first version number for which this query
applies, or 0 for any version
sql: (str) The SQL to execute for the specified version
"""
# Check if we already have SQL for this version
if pg_version in self.versions:
raise DuplicateMetricVersionError
        # Add the versioned SQL
        self.versions[pg_version] = MetricVersion(sql)
        # Clear the resolution cache since the set of versions changed
        self.cached = {}
def get_version(self, pg_version):
"""
        Get an appropriate SQL query for the given PostgreSQL version
Params:
pg_version: (int) The version of PostgreSQL for which to find a query
"""
# Since we're usually going to keep asking for the same version(s),
# we cache the metric version the first time we need it.
if pg_version not in self.cached:
self.cached[pg_version] = None
            # Search the known versions from highest to lowest until we find
            # the newest one that applies to this PostgreSQL version
for v in reversed(sorted(self.versions.keys())):
if pg_version >= v:
self.cached[pg_version] = self.versions[v]
break
# If we didn't find a query, or the query is None (ie: the metric is
# no longer supported for this version), throw an exception
if pg_version not in self.cached or self.cached[pg_version] is None:
raise UnsupportedMetricVersionError
# Return the cached version
return self.cached[pg_version]
class MetricVersion:
"""
A version of a metric
The query is formatted using str.format with a dictionary of variables,
    so you can use things like {table} in the template, then pass table=foo when
retrieving the SQL and '{table}' will be replaced with 'foo'.
Note: Only minimal SQL injection checking is done on this. Basically we
just throw an error if any substitution value contains an apostrophe
or a semicolon
Params:
sql: (str) The SQL query template
"""
def __init__(self, sql):
self.sql = sql
def get_sql(self, args):
"""
Return the SQL for this version after substituting the provided
variables.
Params:
args: (dict) Dictionary of formatting substitutions to apply to the
query template
Exceptions:
SQLInjectionError: if any of the values to be substituted contain an
apostrophe (') or semicolon (;)
"""
# Check for possible SQL injection attacks
for v in args.values():
if "'" in v or ';' in v:
raise SQLInjectionError()
# Format and return the SQL
return self.sql.format(**args)
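    # Templating example (hypothetical metric):
    #   MetricVersion('SELECT count(*) FROM {table}').get_sql({'table': 'pg_class'})
    # returns: SELECT count(*) FROM pg_class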
class IPC:
"""
IPC handler for communication between the agent and server components
Params:
config: (Config) The agent configuration object
mode: (str) Which side of the communication setup this is (agent|server)
Exceptions:
        RuntimeError: if the IPC object's mode is invalid
"""
def __init__(self, config, mode):
# Validate the mode
if mode not in [ 'agent', 'server' ]:
raise RuntimeError("Invalid IPC mode: {}".format(self.mode))
self.config = config
self.mode = mode
# Set up the connection
try:
self.initialize()
except Exception as e:
logger.debug("IPC Initialization error: {}".format(e))
raise
def initialize(self):
"""
Initialize the socket
Exceptions:
OSError: if the socket file already exists and could not be removed
RuntimeError: if the IPC object's mode is invalid
"""
logger.debug("Initializing IPC")
if self.mode == 'server':
# Try to clean up the socket if it exists
try:
logger.debug("Unlinking any former socket")
os.unlink(self.config.ipc_socket)
except OSError:
logger.debug("Caught an exception unlinking socket")
if os.path.exists(self.config.ipc_socket):
logger.debug("Socket stilll exists")
raise
logger.debug("No socket to unlink")
# Create the socket
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
# Set a timeout for accepting connections to be able to catch signals
self.socket.settimeout(1)
# Bind the socket and start listening
logger.debug("Binding socket: {}".format(self.config.ipc_socket))
self.socket.bind(self.config.ipc_socket)
logger.debug("Listening on socket")
self.socket.listen(1)
elif self.mode == 'agent':
# Connect to the socket
self.conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
# Set the timeout for the agent side of the connection
self.conn.settimeout(self.config.ipc_timeout)
else:
raise RuntimeError("Invalid IPC mode: {}".format(self.mode))
logger.debug("IPC initialization complete")
def connect(self):
"""
        Establish and return a connection to the socket (agent mode)
Returns:
The connected socket object
Exceptions:
TimeoutError: if establishing the connection times out
"""
# Connect to the socket
self.conn.connect(self.config.ipc_socket)
# Return the socket from this IPC object
return self.conn
def accept(self):
"""
Accept a connection (server)
Wait for, accept, and return a connection to the socket
Returns:
The accepted socket object
Exceptions:
TimeoutError: if no connection is accepted before the timeout
"""
# Accept the connection
(conn, _) = self.socket.accept()
# Set the timeout on the new socket object
conn.settimeout(self.config.ipc_timeout)
# Return the new socket object
return conn
def send(self, conn, msg):
"""
Send a message across the IPC channel
The message is encoded in UTF-8, and prefixed with a 4 byte, big endian
unsigned integer
Params:
conn: (Socket) the connection to use
msg: (str) The message to send
Exceptions:
        TimeoutError: if the connection times out before the message is fully
            sent
OverflowError: if the size of the message exceeds what fits in a four
byte, unsigned integer
UnicodeError: if the message is not UTF-8 encodable
"""
# Encode the message as bytes
msg_bytes = msg.encode("utf-8")
# Get the byte length of the message
msg_len = len(msg_bytes)
# Encode the message length
msg_len_bytes = msg_len.to_bytes(4, byteorder='big')
# Send the size and message bytes
conn.sendall(msg_len_bytes + msg_bytes)
def recv(self, conn):
"""
Receive a message across the IPC channel
The message is encoded in UTF-8, and prefixed with a 4 byte, big endian
unsigned integer
Params:
conn: (Socket) the connection to use
Returns:
The received message
Exceptions:
        TimeoutError: if the connection times out before the whole message is
            received
UnicodeError: if the message is not UTF-8 decodable
"""
        # Read at least the 4 byte length prefix
        buffer = []
        msg_len = -4
        while msg_len < 0:
            chunk = conn.recv(1024)
            if not chunk:
                raise ConnectionError("Connection closed while receiving message length")
            buffer += chunk
            msg_len = len(buffer) - 4
# Pull out the bytes for the length
msg_len_bytes = buffer[:4]
# Convert the size bytes to an unsigned integer
msg_len = int.from_bytes(msg_len_bytes, byteorder='big')
        # Finish reading the message if we don't have all of it
        while len(buffer) < msg_len + 4:
            chunk = conn.recv(1024)
            if not chunk:
                raise ConnectionError("Connection closed while receiving message body")
            buffer += chunk
# Decode and return the message
return bytes(buffer[4:]).decode("utf-8")
class Agent:
"""
The agent side of the connector
    This mode is intended to be called by the monitoring agent and
communicates with the server side of the connector.
    The key is a comma separated string formatted as:
        <agent>,<cluster name>,<metric name>[,<arg>=<value>...]
Results are printed to stdout.
Params:
args: (argparse.Namespace) Command line arguments
key: (str) The key indicating the requested metric
"""
@staticmethod
def run(args, key):
# Read the agent config
config = Config(args, read_metrics = False, read_clusters = False)
init_logging(config)
# Connect to the IPC socket
ipc = IPC(config, 'agent')
try:
conn = ipc.connect()
# Send a request
ipc.send(conn, key)
# Wait for a response
res = ipc.recv(conn)
except Exception as e:
print("IPC error: {}".format(e))
sys.exit(1)
# Output the response
print(res)
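    # Hypothetical invocation from a monitoring agent, assuming a cluster
    # named 'main' and a metric named 'version' exist in the config:
    #   pgmon -c /etc/pgmon/pgmon.conf 'zabbix,main,version'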
class Server:
"""
The server side of the connector
Params:
args: (argparse.Namespace) Command line arguments
"""
    def __init__(self, args):
        # Keep the parsed arguments so the config can be reloaded
        self.args = args
# Load config
self.config = Config(args)
# Write pid file
# Note: we record the PID file here so it can't be changed with reload
self.pid_file = self.config.pid_file
write_pid_file(self.pid_file)
# Initialize logging
init_logging(self.config)
# Set up the signal handler
        signal.signal(signal.SIGINT, signal_handler)
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGQUIT, signal_handler)
        signal.signal(signal.SIGHUP, signal_handler)
        # Create request queue
logger.debug("Creating request queue")
self.req_queue = queue.Queue(self.config.request_queue_size)
# Spawn worker threads
logger.debug("Spawning worker threads")
self.workers = self.spawn_workers(self.config)
def spawn_workers(self, config):
"""
Spawn all worker threads to process requests
Params:
config: (Config) The agent config object
"""
logger.info("Spawning {} workers".format(config.worker_count))
# Spawn worker threads
workers = [None] * config.worker_count
for i in range(config.worker_count):
workers[i] = Worker(config, self.req_queue)
workers[i].start()
logger.debug("Started thread #{} (tid={})".format(i, workers[i].native_id))
# Return the list of worker threads
return workers
def retire_workers(self, workers):
"""
Retire (terminate) all worker threads in the given list of threads
Params:
workers: (list) List of worker threads to part ways with
"""
logger.info("Retiring {} workers".format(len(workers)))
# Inform the workers that their services are no longer required
for worker in workers:
worker.active = False
# Wait for the workers to turn in their badges
for worker in workers:
worker.join()
def reload_config(self):
"""
        Reload the config, re-initialize logging, spawn new workers, and
        retire the old ones
Exceptions:
InvalidConfigError: Indicates an issue with the config file
OSError: Thrown if there's an issue opening a config file
ValueError: Thrown if there is an encoding error when reading the
config
"""
# Clear the reload flag
global reload
reload = False
        # Read new config
        new_config = Config(self.args)
# Re-init logging in case settings changed
init_logging(new_config)
# Spawn new workers
new_workers = self.spawn_workers(new_config)
# Replace workers
old_workers = self.workers
self.workers = new_workers
# Retire old workers
self.retire_workers(old_workers)
# Adjust other settings
# TODO
# Set the new config as the one the server will use
self.config = new_config
def run(self):
"""
Run the server's main loop
"""
logger.info("Server starting")
# Listen on ipc socket
ipc = IPC(self.config, 'server')
# Enter the main loop
while True:
# Wait for a request connection
try:
logger.debug("Waiting for a connection")
conn = ipc.accept()
except socket.timeout:
conn = None
# See if we should exit
if not running:
break
# Check if we're supposed to reload the config
if reload:
try:
self.reload_config()
except Exception as e:
logger.error("Reload failed: {}".format(e))
# If we just timed out waiting for a request, go back to waiting
if conn is None:
continue
# Get the request string (csv)
try:
key = ipc.recv(conn)
except socket.timeout:
# Handle timeouts when receiving the request
logger.warning("IPC communication timeout receiving request")
conn.close()
continue
# Parse ipc request (csv)
logger.debug("Parsing request key: {}".format(key))
try:
# Split the key into a cluster name, metric name, and list of
# metric arguments
parts = key.split(',', 3)
agent_name = parts[0]
cluster_name = parts[1]
metric_name = parts[2]
# Parse any metric arguments into a dictionary
args_dict = {'agent': agent_name, 'cluster': cluster_name}
if len(parts) > 3:
for arg in parts[3].split(','):
if arg != '':
(k, v) = arg.split('=', 1)
args_dict[k] = v
except Exception as e:
# Handle problems parsing the request into its elements
logger.warning("Received invalid request '{}': {}".format(key, e))
ipc.send(conn, "ERROR: Invalid key")
conn.close()
continue
# Create request object
req = Request(cluster_name, metric_name, args_dict)
# Queue the request
try:
self.req_queue.put(req, timeout=self.config.request_queue_timeout)
            except queue.Full:
                # Handle situations where the queue is full and we didn't get
                # a free slot before the configured timeout
                logger.warning("Failed to queue request, queue is full")
                try:
                    ipc.send(conn, "ERROR: Enqueue timeout")
                except Exception:
                    pass
                conn.close()
                continue
# Spawn a thread to wait for the result
r = Responder(self.config, ipc, conn, req)
r.start()
# Join worker threads
self.retire_workers(self.workers)
# Clean up the PID file
remove_pid_file(self.pid_file)
# Be polite
logger.info("Good bye")
# Gracefully shut down logging
logging.shutdown()
class Worker(threading.Thread):
"""
Worker thread that processes requests (ie: queries the database)
Params:
config: (Config) The agent config object
queue: (Queue) The request queue the worker should pull requests from
"""
def __init__(self, config, queue):
super(Worker, self).__init__()
self.config = config
self.queue = queue
self.active = True
def run(self):
"""
Main processing loop for a worker thread
"""
while True:
# Wait for a request
try:
req = self.queue.get(timeout=1)
except queue.Empty:
req = None
# Check if we're supposed to exit
if not self.active:
                # If we got a request, try to put it back on the queue
                if req is not None:
                    try:
                        self.queue.put(req, timeout=1)
                        logger.info("Requeued request at worker exit")
                    except queue.Full:
                        logger.warning("Failed to requeue request at worker exit")
                        req.set_result("ERROR: Failed to requeue at Worker exit")
logger.info("Worker exiting: tid={}".format(self.native_id))
break
# If we got here because we waited too long for a request, go back to waiting
if req is None:
continue
            # Find the requested metric
try:
metric = self.config.metrics[req.metric_name]
except KeyError:
req.set_result("ERROR: Unknown metric: {}".format(req.metric_name))
continue
# Get the DB version
try:
pg_version = self.config.get_pg_version(req.cluster_name)
except Exception as e:
req.set_result("Failed to retrieve database version for cluster: {}".format(req.cluster_name))
logger.error("Failed to get Postgresql version for cluster {}: {}".format(req.cluster_name, e))
continue
# Get the query to use
try:
mv = metric.get_version(pg_version)
except UnsupportedMetricVersionError:
                # Handle unsupported metric versions
                req.set_result("Unsupported PostgreSQL version for metric")
continue
# Identify the database to query
try:
dbname = req.args['db']
except KeyError:
dbname = None
# Get any positional query args
try:
pos_args = req.args['pos'].split(':')
logger.debug("Found positional args for {}: {}".format(req.metric_name, ','.join(pos_args)))
except KeyError:
pos_args = []
logger.debug("No positional args found for {}".format(req.metric_name))
# Query the database
try:
db = DB(self.config, req.cluster_name, dbname)
res = db.query(mv.get_sql(req.args), pos_args)
db.close()
except Exception as e:
# Handle database errors
logger.error("Database error: {}".format(e))
                # Make sure the database connection is closed (ie: if the query timed out)
                try:
                    db.close()
                except Exception:
pass
req.set_result("Failed to query database")
continue
# Set the result on the request
if metric.type == 'value':
if len(res) == 0:
req.set_result("Empty result set")
else:
req.set_result("{}".format(list(res[0].values())[0]))
            elif metric.type == 'row':
                if len(res) == 0:
                    req.set_result("Empty result set")
                else:
                    req.set_result(json.dumps(res[0]))
elif metric.type == 'column':
req.set_result(json.dumps([list(r.values())[0] for r in res]))
elif metric.type == 'set':
req.set_result(json.dumps(res))
class Responder(threading.Thread):
"""
Thread responsible for replying to requests
Params:
config: (Config) The agent config object
ipc: (IPC) The IPC object used for communication
conn: (Socket) The connected socket to communicate with
req: (Request) The request object to handle
"""
    def __init__(self, config, ipc, conn, req):
        super(Responder, self).__init__()
        self.config = config
        self.ipc = ipc
        self.conn = conn
        self.req = req
    def run(self):
        # Wait for a result
        result = self.req.get_result()
        # Send the result back to the client
        try:
            self.ipc.send(self.conn, result)
        except Exception as e:
            logger.warning("Failed to reply to agent: {}".format(e))
def main():
# Set up command line argument parser
parser = argparse.ArgumentParser(
prog='pgmon',
        description='A bridge between monitoring tools and PostgreSQL')
# General options
parser.add_argument('-c', '--config', default=None)
parser.add_argument('-v', '--verbose', action='store_true')
parser.add_argument('-s', '--socket', default=None)
parser.add_argument('-l', '--logfile', default=None)
parser.add_argument('-p', '--pidfile', default=None)
# Operational mode
parser.add_argument('-S', '--server', action='store_true')
parser.add_argument('-r', '--reload', action='store_true')
# Agent options
parser.add_argument('key', nargs='?')
# Parse command line arguments
args = parser.parse_args()
if args.server:
# Try to start running in server mode
try:
server = Server(args)
except Exception as e:
sys.exit("Failed to start server: {}".format(e))
try:
server.run()
except Exception as e:
sys.exit("Caught an unexpected runtime error: {}".format(e))
elif args.reload:
# Try to signal a running server to reload its config
try:
config = Config(args, read_metrics = False)
except Exception as e:
sys.exit("Failed to read config file: {}".format(e))
# Read the PID file
try:
pid = read_pid_file(config.pid_file)
except Exception as e:
sys.exit("Failed to read PID file: {}".format(e))
# Signal the server to reload
try:
os.kill(pid, signal.SIGHUP)
        except Exception as e:
            sys.exit("Failed to signal server: {}".format(e))
    else:
        # Start running in agent mode
        if args.key is None:
            parser.error("No metric key specified")
        Agent.run(args, args.key)
if __name__ == '__main__':
main()
##
# Tests
##
# pytest is only needed when running the test suite, so don't require it for
# normal agent/server operation
try:
    import pytest
except ImportError:
    pass
class TestConfig:
pass
class TestDB:
pass
class TestRequest:
def test_request_creation(self):
        # Create request with no args
req1 = Request('c1', 'foo', {})
assert req1.cluster_name == 'c1'
assert req1.metric_name == 'foo'
assert len(req1.args) == 0
assert req1.complete.locked()
        # Create request with args
req2 = Request('c1', 'foo', {'arg1': 'value1', 'arg2': 'value2'})
assert req2.metric_name == 'foo'
assert len(req2.args) == 2
assert req2.complete.locked()
def test_request_lock(self):
req1 = Request('c1', 'foo', {})
assert req1.complete.locked()
req1.set_result('blah')
assert not req1.complete.locked()
assert 'blah' == req1.get_result()
class TestMetric:
def test_metric_creation(self):
# Test basic creation
m1 = Metric('foo', 'value')
assert m1.name == 'foo'
assert len(m1.versions) == 0
assert m1.type == 'value'
        assert len(m1.cached) == 0
def test_metric_add_version(self):
# Test adding versions
m1 = Metric('foo', 'value')
assert len(m1.versions) == 0
m1.add_version(0, 'default')
assert len(m1.versions) == 1
m1.add_version(120003, 'v12.3')
assert len(m1.versions) == 2
# Make sure added versions are correct
assert m1.versions[0].sql == 'default'
assert m1.versions[120003].sql == 'v12.3'
def test_metric_get_version(self):
# Test retrieving metric versions
m1 = Metric('foo', 'value')
m1.add_version(100000, 'v10.0')
m1.add_version(120000, 'v12.0')
        # Make sure cache is initially empty
        assert len(m1.cached) == 0
        assert m1.get_version(110003).sql == 'v10.0'
        # Make sure cache is set
        assert 110003 in m1.cached
        assert m1.cached[110003].sql == 'v10.0'
# Make sure returned value changes with version
assert m1.get_version(120000).sql == 'v12.0'
assert m1.get_version(150005).sql == 'v12.0'
# Make sure an error is thrown when no version matches
with pytest.raises(UnsupportedMetricVersionError):
m1.get_version(90603)
# Add a default version
m1.add_version(0, 'default')
assert m1.get_version(90603).sql == 'default'
assert m1.get_version(110003).sql == 'v10.0'
assert m1.get_version(120000).sql == 'v12.0'
assert m1.get_version(150005).sql == 'v12.0'
class TestMetricVersion:
def test_metric_version_creation(self):
mv1 = MetricVersion('test')
assert mv1.sql == 'test'
def test_metric_version_templating(self):
mv1 = MetricVersion('foo')
assert mv1.get_sql({}) == 'foo'
mv2 = MetricVersion('foo {a1} {a3} {a2}')
assert mv2.get_sql({'a1': 'bar', 'a2': 'blah blah blah', 'a3': 'baz'}) == 'foo bar baz blah blah blah'
class TestIPC:
pass
class TestAgent:
pass
class TestServer:
pass
class TestWorker:
pass
class TestResponder:
pass