Add standardised SQL connection setup for microservices.

This commit is contained in:
2026-04-25 21:42:07 +01:00
parent 3d164132ab
commit 76c63fc3ef
9 changed files with 325 additions and 39 deletions

View File

@@ -14,6 +14,8 @@ use aws_config::BehaviorVersion;
use aws_sdk_s3::Client;
use aws_sdk_s3::config::{Credentials, Region};
use aws_sdk_s3::primitives::ByteStream;
use diesel::Connection as DieselConnection;
use diesel::pg::PgConnection;
use futures_util::StreamExt;
use lapin::options::{
BasicAckOptions, BasicConsumeOptions, BasicNackOptions, BasicPublishOptions,
@@ -89,6 +91,7 @@ pub struct Microservice {
}
impl Microservice {
#[cfg(feature = "python")]
fn new_case_key(
name: impl Into<String>,
config_host: impl Into<String>,
@@ -105,12 +108,16 @@ impl Microservice {
/// Create a new microservice runtime.
///
/// `process` accepts an inbound request ID, a `read_file` function, and a
/// `write_file` function, and then returns a list of
/// `write_file` function, and a mutable PostgreSQL connection, and then
/// returns a list of
/// `(result_id, case_variable)` tuples. Case variables can be any
/// serializable primitive, such as `String`, `bool`, or integers.
pub fn new<F, C>(name: impl Into<String>, config_host: impl Into<String>, process: F) -> Self
where
F: Fn(u64, &ReadFileFn, &WriteFileFn) -> Result<Vec<(u64, C)>, AnyError> + Send + Sync + 'static,
F: Fn(u64, &ReadFileFn, &WriteFileFn, &mut PgConnection) -> Result<Vec<(u64, C)>, AnyError>
+ Send
+ Sync
+ 'static,
C: Serialize + 'static,
{
init_tracing();
@@ -119,7 +126,13 @@ impl Microservice {
read_file: Arc<ReadFileFn>,
write_file: Arc<WriteFileFn>,
| -> Result<Vec<(u64, CaseKey)>, AnyError> {
let outputs = process(request, read_file.as_ref(), write_file.as_ref())?;
let mut connection = establish_pg_connection().map_err(|e| {
format!(
"failed to establish PostgreSQL connection for request {}: {}",
request, e
)
})?;
let outputs = process(request, read_file.as_ref(), write_file.as_ref(), &mut connection)?;
let mut mapped = Vec::with_capacity(outputs.len());
for (id, case) in outputs {
let value = serde_json::to_value(case)
@@ -148,7 +161,9 @@ impl Microservice {
.build()?;
let s3_client = runtime.block_on(fetch_s3_client_from_sys_map())?;
runtime.block_on(self.run_consumer(config.inbound, route_map, amqp_url, s3_client))
runtime
.block_on(self.run_consumer(config.inbound, route_map, amqp_url, s3_client))
.map_err(|e| format!("microservice '{}' failed: {}", self.name, e).into())
}
fn fetch_config(&self) -> Result<QueueConfig, AnyError> {
@@ -240,15 +255,23 @@ impl Microservice {
guard.write_file(&bucket, id)
});
let outputs = (self.process)(request_id, Arc::clone(&read_file), Arc::clone(&write_file))?;
let outputs = (self.process)(request_id, Arc::clone(&read_file), Arc::clone(&write_file))
.map_err(|e| format!("request {}: process callback failed: {}", request_id, e))?;
{
let mut guard = file_context
.lock()
.map_err(|e| format!("file context lock poisoned for finalize: {}", e))?;
guard.finalize(s3_client.as_ref())?
guard
.finalize(s3_client.as_ref())
.map_err(|e| format!("request {}: finalize/upload failed: {}", request_id, e))?
}
publish_outputs(&channel, outputs, &route_map).await?;
delivery.ack(BasicAckOptions::default()).await?;
publish_outputs(&channel, outputs, &route_map)
.await
.map_err(|e| format!("request {}: publish failed: {}", request_id, e))?;
delivery
.ack(BasicAckOptions::default())
.await
.map_err(|e| format!("request {}: ack failed: {}", request_id, e))?;
}
Ok(())
@@ -373,7 +396,16 @@ fn upload_to_s3(
.key(&key)
.body(body)
.send()
.await?;
.await
.map_err(|e| {
format!(
"failed to upload S3 object bucket='{}' key='{}' path='{}': {:?}",
bucket_name,
key,
path.display(),
e
)
})?;
Ok::<(), AnyError>(())
})
@@ -404,6 +436,15 @@ struct RabbitMqConfig {
pass: Vec<String>,
}
/// PostgreSQL connection settings as deserialized from the sys-map `/db`
/// JSON response.
///
/// Every field arrives as a list; the consumer of this struct passes each
/// one through `single_value`, which rejects any list whose length is not
/// exactly one.
#[derive(Debug, Deserialize)]
struct DatabaseConfig {
// TCP port of the PostgreSQL server.
port: Vec<u16>,
// Host name or address of the PostgreSQL server.
host: Vec<String>,
// Role used to authenticate.
username: Vec<String>,
// Name of the database to connect to.
database: Vec<String>,
// Lookup key handed to `resolve_password_from_pass`; presumably not the
// password itself — confirm against the sys-map schema.
pass: Vec<String>,
}
#[derive(Debug, Deserialize)]
struct ObjectStorageConfig {
host: Vec<String>,
@@ -430,6 +471,42 @@ fn fetch_rabbitmq_url_from_sys_map() -> Result<String, AnyError> {
Ok(format!("amqp://{}:{}@{}:{}/%2f", username, pass, host, port))
}
/// Percent-encode `raw` so it can be embedded as a single component of a
/// connection URI.
///
/// Only the RFC 3986 "unreserved" characters (`A-Z a-z 0-9 - . _ ~`) are
/// passed through; every other byte is emitted as `%XX` (uppercase hex).
fn percent_encode_uri_component(raw: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(raw.len());
    for byte in raw.bytes() {
        match byte {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'~' => {
                encoded.push(char::from(byte));
            }
            _ => {
                encoded.push('%');
                encoded.push(char::from(HEX[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
        }
    }
    encoded
}

/// Fetch the PostgreSQL settings from sys-map and build a `postgres://` URL.
///
/// The configuration is downloaded from the sys-map `/db` endpoint. When a
/// Tokio runtime handle is current, the async `reqwest` client is driven via
/// `block_in_place` + `block_on`; otherwise the blocking `reqwest` client is
/// used directly.
///
/// The username, password, and database name are percent-encoded before being
/// placed in the URL, so credentials containing URI-reserved characters
/// (`@`, `:`, `/`, `#`, `?`, `%`, ...) do not corrupt the connection string.
///
/// # Errors
///
/// Returns an error if the HTTP request fails or returns a non-success
/// status, if the body is not valid `DatabaseConfig` JSON, if any field does
/// not hold exactly one value, or if the password cannot be resolved.
fn fetch_database_url_from_sys_map() -> Result<String, AnyError> {
    const SYS_MAP_DB_URL: &str = "https://sys-map.slingshot.cv/db";

    let config = if tokio::runtime::Handle::try_current().is_ok() {
        // NOTE(review): `block_in_place` is documented to require the
        // multi-threaded runtime flavor — confirm how the runtime is built.
        tokio::task::block_in_place(|| {
            tokio::runtime::Handle::current().block_on(async {
                let response = reqwest::get(SYS_MAP_DB_URL).await?;
                let response = response.error_for_status()?;
                response
                    .json::<DatabaseConfig>()
                    .await
                    .map_err(|e| -> AnyError { Box::new(e) })
            })
        })?
    } else {
        let response = reqwest::blocking::get(SYS_MAP_DB_URL)?;
        let response = response.error_for_status()?;
        response.json::<DatabaseConfig>()?
    };

    let port = single_value(&config.port, "port")?;
    let host = single_value(&config.host, "host")?;
    let username = single_value(&config.username, "username")?;
    let database = single_value(&config.database, "database")?;
    let pass_key = single_value(&config.pass, "pass")?;
    let password = resolve_password_from_pass(&pass_key)?;

    // The password is deliberately left out of the log line.
    info!(
        "Fetched DB config from sys-map: host={}, port={}, username={}, database={}",
        host, port, username, database
    );

    Ok(format!(
        "postgres://{}:{}@{}:{}/{}",
        percent_encode_uri_component(&username),
        percent_encode_uri_component(&password),
        host,
        port,
        percent_encode_uri_component(&database)
    ))
}
async fn fetch_s3_client_from_sys_map() -> Result<Arc<Client>, AnyError> {
let response = reqwest::get("https://sys-map.slingshot.cv/object-storage").await?;
let response = response.error_for_status()?;
@@ -524,6 +601,12 @@ fn resolve_password_from_pass(pass_key: &str) -> Result<String, AnyError> {
Ok(password)
}
/// Open a new PostgreSQL connection using the settings published by sys-map.
///
/// # Errors
///
/// Propagates any failure from fetching the database URL, and wraps a
/// connection failure in a message that names sys-map as the config source.
fn establish_pg_connection() -> Result<PgConnection, AnyError> {
    let url = fetch_database_url_from_sys_map()?;
    match PgConnection::establish(&url) {
        Ok(connection) => Ok(connection),
        Err(cause) => {
            let message = format!(
                "failed to connect to PostgreSQL using sys-map DB config: {}",
                cause
            );
            Err(message.into())
        }
    }
}
fn single_value<T: Clone>(values: &[T], field_name: &str) -> Result<T, AnyError> {
if values.len() != 1 {
return Err(format!(
@@ -633,9 +716,30 @@ async fn publish_outputs(
payload.as_bytes(),
BasicProperties::default(),
)
.await?;
confirm.await?;
.await
.map_err(|e| {
format!(
"failed to publish result_id={} to queue='{}': {}",
result_id,
queue,
e
)
})?;
confirm.await.map_err(|e| {
format!(
"broker rejected publish result_id={} queue='{}': {}",
result_id,
queue,
e
)
})?;
}
} else {
info!(
"No outbound queues configured for case variable {:?}; skipping publish for result_id={}",
case_var,
result_id
);
}
}
@@ -745,33 +849,54 @@ fn run_python_process(
read_file: Arc<ReadFileFn>,
write_file: Arc<WriteFileFn>,
) -> Result<Vec<(u64, CaseKey)>, AnyError> {
let database_url = fetch_database_url_from_sys_map()?;
Python::with_gil(|py| -> Result<Vec<(u64, CaseKey)>, AnyError> {
let py_read = Py::new(py, PyReadFileFn { inner: read_file })
.map_err(|e| format!("failed to build Python ReadFileFn wrapper: {}", e))?;
let py_write = Py::new(py, PyWriteFileFn { inner: write_file })
.map_err(|e| format!("failed to build Python WriteFileFn wrapper: {}", e))?;
let sqlalchemy = py
.import("sqlalchemy")
.map_err(|e| format!("failed to import sqlalchemy: {}", e))?;
let engine = sqlalchemy
.getattr("create_engine")
.and_then(|f| f.call1((database_url.as_str(),)))
.map_err(|e| format!("failed to create SQLAlchemy engine: {}", e))?;
let connection = engine
.call_method0("connect")
.map_err(|e| format!("failed to open SQLAlchemy connection: {}", e))?;
let returned = process
.call1(py, (request, py_read, py_write))
.map_err(|e| format!("Python process callback failed: {}", e))?;
let outputs_result = (|| -> Result<Vec<(u64, CaseKey)>, AnyError> {
let returned = process
.call1(py, (request, py_read, py_write, connection.clone()))
.map_err(|e| format!("Python process callback failed: {}", e))?;
let iter = returned
.bind(py)
.try_iter()
.map_err(|e| format!("process return value must be iterable: {}", e))?;
let iter = returned
.bind(py)
.try_iter()
.map_err(|e| format!("process return value must be iterable: {}", e))?;
let mut outputs = Vec::new();
for item in iter {
let item = item.map_err(|e| format!("failed to iterate process outputs: {}", e))?;
let (id, case_obj): (u64, Py<PyAny>) = item
.extract()
.map_err(|e| format!("each output must be a tuple (int, case): {}", e))?;
let case = case_key_from_py_value(case_obj.bind(py))
.map_err(|e| format!("invalid case variable: {}", e))?;
outputs.push((id, case));
let mut outputs = Vec::new();
for item in iter {
let item = item.map_err(|e| format!("failed to iterate process outputs: {}", e))?;
let (id, case_obj): (u64, Py<PyAny>) = item
.extract()
.map_err(|e| format!("each output must be a tuple (int, case): {}", e))?;
let case = case_key_from_py_value(case_obj.bind(py))
.map_err(|e| format!("invalid case variable: {}", e))?;
outputs.push((id, case));
}
Ok(outputs)
})();
let close_result = connection.call_method0("close");
if let Err(err) = close_result {
return Err(format!("failed to close SQLAlchemy connection: {}", err).into());
}
Ok(outputs)
outputs_result
})
}