Implement a startup script which runs once on microservice start.

This commit is contained in:
Seán Healy
2026-05-08 20:45:49 +01:00
parent 13b82d66c6
commit c823a73bc4
6 changed files with 246 additions and 21 deletions

View File

@@ -9,6 +9,9 @@ assumptions about a microservice:
1. A microservice listens to incoming requests on its own dedicated and
singular queue (RabbitMQ).
2. Incoming requests are in the form of a 64-bit unsigned integer (`u64`).
2. Microservices can run optional startup work via a `startup` function, which
takes three arguments: a `read_file` function, a `write_file` function, and
a database ORM `connection`.
2. Microservices process requests via a `process` function, which takes four
arguments: the incoming request (`u64`), a `read_file` function, a
`write_file` function, and a database ORM `connection`.
@@ -85,6 +88,15 @@ from slingshot_microservice.typing import ReadFileFn, WriteFileFn
from slingshot_microservice import Microservice
def startup(
    read_file: ReadFileFn,
    write_file: WriteFileFn,
    connection: Connection,
) -> Generator[tuple[int, bool | int | str], None, None]:
    """Optional startup hook; this example performs no work.

    Yields nothing, so no ``(result_id, case_variable)`` tuples are
    routed at startup.
    """
    # Delegating to an empty iterable keeps this a generator function
    # while producing no items.
    yield from ()
def process(
request: int,
read_file: ReadFileFn,
@@ -100,7 +112,12 @@ def process(
yield (request, True)
microservice = Microservice("simple-py-microservice", "sys-map.slingshot.cv", process)
microservice = Microservice(
"simple-py-microservice",
"sys-map.slingshot.cv",
startup,
process,
)
microservice.start()
```
@@ -113,7 +130,8 @@ editors and type-checkers:
|---|---|
| `ReadFileFn` | Callable returned by `read_file(key, id)` behaves like `BinaryIO` |
| `WriteFileFn` | Callable returned by `write_file(key, id)` behaves like `BinaryIO` |
| `ProcessFn` | The generator signature expected by `Microservice` with `(request, read_file, write_file, connection)` |
| `StartupFn` | The generator signature expected by `Microservice` startup hook with `(read_file, write_file, connection)` |
| `ProcessFn` | The generator signature expected by `Microservice` process callback with `(request, read_file, write_file, connection)` |
| `CaseVariable` | `bool \| int \| str` valid case variable types |
### Publishing Wheels
@@ -154,11 +172,20 @@ fn process(
Ok(vec![(request, "case_a".to_string())])
}
/// Startup hook for the example service: does no work and emits no
/// `(result_id, case_variable)` outputs.
fn startup(
    read_file: &ReadFileFn,
    write_file: &WriteFileFn,
    connection: &mut PgConnection,
) -> Result<Vec<(u64, String)>, AnyError> {
    Ok(vec![])
}
fn main() {
// Create a new microservice instance with the processing function
// Create a new microservice instance with startup and process functions
let microservice = Microservice::new(
"simple-microservice",
"sys-map.example.com",
startup,
process
);
@@ -220,7 +247,11 @@ actual secrets with `pass show <key>` before constructing the S3 client.
When the microservice first starts up, it makes a request to the configuration
service to get the queue metadata. Then it starts to listen to the inbound
queue. Inbound requests are processed by the user-programmed `process`
queue. Before consuming inbound requests, the user-programmed `startup`
function runs once with `(read_file, write_file, connection)` and may return
zero or more `(result_id, case_variable)` tuples for normal outbound routing.
Inbound requests are then processed by the user-programmed `process`
function, which is called with `(request, read_file, write_file, connection)`
and returns a set of tuples of the form `(result_id, case_variable)`.
@@ -247,6 +278,10 @@ The output queue routing step looks like this:
Pseudocode:
```
for each (result_id, case_variable) in startup(read_file, write_file, connection):
for each outbound_queue in config.out[case_variable]:
send result_id to outbound_queue
for each (result_id, case_variable) in process(request, read_file, write_file, connection):
for each outbound_queue in config.out[case_variable]:
send result_id to outbound_queue

View File

@@ -3,6 +3,15 @@ from slingshot_microservice import Microservice
from typing import Generator
from sqlalchemy.engine.base import Connection
def startup(
    read_file: ReadFileFn,
    write_file: WriteFileFn,
    _connection: Connection,
) -> Generator[tuple[int, bool | int | str], None, None]:
    """Startup hook for the example service; intentionally a no-op.

    Yields nothing, so no ``(result_id, case_variable)`` tuples are
    produced at startup.
    """
    # An empty delegation keeps this a generator function with zero items.
    yield from ()
def process(
request: int,
read_file: ReadFileFn,
@@ -16,5 +25,5 @@ def process(
yield (request, True)
microservice = Microservice("simple-py-microservice", "sys-map.slingshot.cv", process)
microservice = Microservice("simple-py-microservice", "sys-map.slingshot.cv", startup, process)
microservice.start()

View File

@@ -18,8 +18,21 @@ fn process(
Ok(vec![(request, "case_a".to_string())])
}
/// Startup hook for the example service: performs no work and returns an
/// empty list of `(result_id, case_variable)` outputs.
fn startup(
    _read_file: &ReadFileFn,
    _write_file: &WriteFileFn,
    _connection: &mut PgConnection,
) -> Result<Vec<(u64, String)>, AnyError> {
    Ok(vec![])
}
fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let microservice = Microservice::new("simple-microservice", "sys-map.slingshot.cv", process);
let microservice = Microservice::new(
"simple-microservice",
"sys-map.slingshot.cv",
startup,
process,
);
microservice.start()?;
Ok(())

View File

@@ -4,7 +4,7 @@ build-backend = "maturin"
[project]
name = "slingshot-microservice"
version = "0.1.5"
version = "0.1.6"
description = "Opinionated Rust framework for queue-driven microservices"
license = { text = "MIT" }
requires-python = ">=3.8"

View File

@@ -33,10 +33,20 @@ class ProcessFn(Protocol):
) -> Generator[tuple[int, CaseVariable], None, None]: ...
class StartupFn(Protocol):
    """Call signature of a startup hook.

    The hook is called as ``startup(read_file, write_file, connection)``
    and returns a generator of ``(int, CaseVariable)`` tuples.
    """

    def __call__(
        self,
        read_file: ReadFileFn,
        write_file: WriteFileFn,
        connection: SqlAlchemyConnection,
    ) -> Generator[tuple[int, CaseVariable], None, None]:
        ...
__all__ = [
"CaseVariable",
"ReadFileFn",
"WriteHandle",
"WriteFileFn",
"ProcessFn",
"StartupFn",
]

View File

@@ -49,6 +49,12 @@ type ProcessFn = dyn Fn(u64, Arc<ReadFileFn>, Arc<WriteFileFn>) -> Result<Vec<(u
+ Sync
+ 'static;
/// Type-erased startup callback.
///
/// Receives shared `read_file` and `write_file` handles and returns the
/// `(id, CaseKey)` pairs it produced, or an error. The `Send + Sync +
/// 'static` bounds let the callback be held behind an `Arc`.
type StartupFn =
dyn Fn(Arc<ReadFileFn>, Arc<WriteFileFn>) -> Result<Vec<(u64, CaseKey)>, AnyError>
+ Send
+ Sync
+ 'static;
static REQUEST_FILE_CONTEXT_COUNTER: AtomicU64 = AtomicU64::new(1);
static DATABASE_URL_CACHE: OnceLock<String> = OnceLock::new();
static DATABASE_POOL_CACHE: OnceLock<Pool<ConnectionManager<PgConnection>>> = OnceLock::new();
@@ -90,6 +96,7 @@ enum CaseKey {
pub struct Microservice {
// Service name supplied to the constructor.
name: String,
// Configuration host supplied to the constructor.
config_host: String,
// Type-erased startup callback (takes `read_file` and `write_file`).
startup: Arc<StartupFn>,
// Type-erased request callback (takes the `u64` request, `read_file`
// and `write_file`).
process: Arc<ProcessFn>,
}
@@ -98,32 +105,60 @@ impl Microservice {
/// Build a `Microservice` from callbacks that are already type-erased and
/// already return `CaseKey` outputs.
///
/// Tracing is initialised before the value is constructed.
fn new_case_key(
    name: impl Into<String>,
    config_host: impl Into<String>,
    startup: Arc<StartupFn>,
    process: Arc<ProcessFn>,
) -> Self {
    init_tracing();

    let name = name.into();
    let config_host = config_host.into();
    Self {
        name,
        config_host,
        startup,
        process,
    }
}
/// Create a new microservice runtime.
///
/// `process` accepts an inbound request ID, a `read_file` function, and a
/// `write_file` function, and a mutable PostgreSQL connection, and then
/// returns a list of
/// `startup` runs once when the service starts, before queue consumption,
/// and receives `read_file`, `write_file`, and a mutable PostgreSQL
/// connection.
///
/// `process` accepts an inbound request ID, `read_file`, `write_file`, and
/// a mutable PostgreSQL connection, and then returns a list of
/// `(result_id, case_variable)` tuples. Case variables can be any
/// serializable primitive, such as `String`, `bool`, or integers.
pub fn new<F, C>(name: impl Into<String>, config_host: impl Into<String>, process: F) -> Self
pub fn new<F, S, C>(
name: impl Into<String>,
config_host: impl Into<String>,
startup: S,
process: F,
) -> Self
where
F: Fn(u64, &ReadFileFn, &WriteFileFn, &mut PgConnection) -> Result<Vec<(u64, C)>, AnyError>
+ Send
+ Sync
+ 'static,
S: Fn(&ReadFileFn, &WriteFileFn, &mut PgConnection) -> Result<Vec<(u64, C)>, AnyError>
+ Send
+ Sync
+ 'static,
C: Serialize + 'static,
{
init_tracing();
let startup_wrapper = move |
read_file: Arc<ReadFileFn>,
write_file: Arc<WriteFileFn>,
| -> Result<Vec<(u64, CaseKey)>, AnyError> {
let mut connection = establish_pg_connection().map_err(|e| {
format!(
"failed to checkout PostgreSQL connection from pool for startup callback: {}",
e
)
})?;
let outputs = startup(read_file.as_ref(), write_file.as_ref(), &mut connection)?;
map_outputs_with_case_key(outputs)
};
let process_wrapper = move |
request: u64,
read_file: Arc<ReadFileFn>,
@@ -136,19 +171,13 @@ impl Microservice {
)
})?;
let outputs = process(request, read_file.as_ref(), write_file.as_ref(), &mut connection)?;
let mut mapped = Vec::with_capacity(outputs.len());
for (id, case) in outputs {
let value = serde_json::to_value(case)
.map_err(|e| format!("case variable must be serializable to JSON: {}", e))?;
mapped.push((id, case_key_from_value(&value)?));
}
Ok(mapped)
map_outputs_with_case_key(outputs)
};
Self {
name: name.into(),
config_host: config_host.into(),
startup: Arc::new(startup_wrapper),
process: Arc::new(process_wrapper),
}
}
@@ -190,6 +219,61 @@ impl Microservice {
declare_queues(&channel, &inbound_queue, &route_map).await?;
let startup_request_id = 0_u64;
let startup_context = Arc::new(Mutex::new(RequestFileContext::new(startup_request_id)?));
let startup_read_context = Arc::clone(&startup_context);
let startup_write_context = Arc::clone(&startup_context);
let startup_s3_read_client = Arc::clone(&s3_client);
let startup_read_bucket_cache = Arc::clone(&bucket_name_cache);
let startup_write_bucket_cache = Arc::clone(&bucket_name_cache);
let startup_config_host = self.config_host.clone();
let startup_microservice_name = self.name.clone();
let startup_read_config_host = startup_config_host.clone();
let startup_read_microservice_name = startup_microservice_name.clone();
let startup_read_file: Arc<ReadFileFn> = Arc::new(move |key: &str, id: u64| -> Result<ReadFile, AnyError> {
let bucket = resolve_bucket_name(
&startup_read_config_host,
&startup_read_microservice_name,
&startup_read_bucket_cache,
key,
)?;
let mut guard = startup_read_context
.lock()
.map_err(|e| format!("file context lock poisoned for startup read_file: {}", e))?;
guard.read_file(startup_s3_read_client.as_ref(), &bucket, id)
});
let startup_write_file: Arc<WriteFileFn> = Arc::new(move |key: &str, id: u64| -> Result<File, AnyError> {
let bucket = resolve_bucket_name(
&startup_config_host,
&startup_microservice_name,
&startup_write_bucket_cache,
key,
)?;
let mut guard = startup_write_context
.lock()
.map_err(|e| format!("file context lock poisoned for startup write_file: {}", e))?;
guard.write_file(&bucket, id)
});
let startup_outputs = (self.startup)(
Arc::clone(&startup_read_file),
Arc::clone(&startup_write_file),
)
.map_err(|e| format!("startup callback failed: {}", e))?;
{
let mut guard = startup_context
.lock()
.map_err(|e| format!("file context lock poisoned for startup finalize: {}", e))?;
guard
.finalize(s3_client.as_ref())
.map_err(|e| format!("startup finalize/upload failed: {}", e))?
}
publish_outputs(&channel, startup_outputs, &route_map)
.await
.map_err(|e| format!("startup publish failed: {}", e))?;
let mut consumer = channel
.basic_consume(
&inbound_queue,
@@ -700,6 +784,17 @@ fn case_key_from_value(case_value: &Value) -> Result<CaseKey, AnyError> {
}
}
/// Convert callback outputs whose case variables are arbitrary serializable
/// values into outputs keyed by `CaseKey`.
///
/// Each case variable is serialized to a JSON value and then passed to
/// `case_key_from_value`. Order is preserved.
///
/// # Errors
///
/// Stops at the first element that fails to serialize or that
/// `case_key_from_value` rejects, and returns that error.
fn map_outputs_with_case_key<C: Serialize>(outputs: Vec<(u64, C)>) -> Result<Vec<(u64, CaseKey)>, AnyError> {
    let capacity = outputs.len();
    outputs.into_iter().try_fold(
        Vec::with_capacity(capacity),
        |mut converted, (result_id, case_variable)| -> Result<Vec<(u64, CaseKey)>, AnyError> {
            let json = serde_json::to_value(case_variable)
                .map_err(|e| format!("case variable must be serializable to JSON: {}", e))?;
            converted.push((result_id, case_key_from_value(&json)?));
            Ok(converted)
        },
    )
}
async fn declare_queues(
channel: &Channel,
inbound_queue: &str,
@@ -955,6 +1050,7 @@ fn create_python_sqlalchemy_engine() -> Result<Py<PyAny>, AnyError> {
struct PyMicroservice {
// Service name supplied to the constructor.
name: String,
// Configuration host supplied to the constructor.
config_host: String,
// Python object for the startup callback; the constructor rejects
// non-callables.
startup: Py<PyAny>,
// Python object for the process callback; the constructor rejects
// non-callables.
process: Py<PyAny>,
}
@@ -962,8 +1058,12 @@ struct PyMicroservice {
#[pymethods]
impl PyMicroservice {
#[new]
fn new(name: String, config_host: String, process: Py<PyAny>) -> PyResult<Self> {
fn new(name: String, config_host: String, startup: Py<PyAny>, process: Py<PyAny>) -> PyResult<Self> {
Python::with_gil(|py| {
if !startup.bind(py).is_callable() {
return Err(PyTypeError::new_err("startup must be callable"));
}
if !process.bind(py).is_callable() {
return Err(PyTypeError::new_err("process must be callable"));
}
@@ -971,19 +1071,26 @@ impl PyMicroservice {
Ok(Self {
name,
config_host,
startup,
process,
})
})
}
fn start(&self) -> PyResult<()> {
let startup = Python::with_gil(|py| self.startup.clone_ref(py));
let process = Python::with_gil(|py| self.process.clone_ref(py));
let engine = create_python_sqlalchemy_engine().map_err(any_error_to_py)?;
let startup_engine = Python::with_gil(|py| engine.clone_ref(py));
let process_engine = Python::with_gil(|py| engine.clone_ref(py));
let microservice = Microservice::new_case_key(
self.name.clone(),
self.config_host.clone(),
Arc::new(move |read_file, write_file| {
run_python_startup(&startup, &startup_engine, read_file, write_file)
}),
Arc::new(move |request, read_file, write_file| {
run_python_process(&process, &engine, request, read_file, write_file)
run_python_process(&process, &process_engine, request, read_file, write_file)
}),
);
@@ -991,6 +1098,57 @@ impl PyMicroservice {
}
}
/// Run the Python `startup` callable and collect its outputs.
///
/// Under the GIL this wraps `read_file` / `write_file` for Python, opens a
/// connection via `engine.connect()`, calls
/// `startup(read_file, write_file, connection)`, and iterates the returned
/// value. Each item must extract as an `(int, case)` tuple; the case is
/// converted with `case_key_from_py_value`.
///
/// # Errors
///
/// Returns an error if a wrapper cannot be built, the connection cannot be
/// opened, the callback raises, its return value is not iterable, or an
/// item is malformed. The connection's `close()` is called once the callback
/// stage has finished, whether it succeeded or failed; if `close()` itself
/// fails, that error is returned in place of the callback's result.
#[cfg(feature = "python")]
fn run_python_startup(
    startup: &Py<PyAny>,
    engine: &Py<PyAny>,
    read_file: Arc<ReadFileFn>,
    write_file: Arc<WriteFileFn>,
) -> Result<Vec<(u64, CaseKey)>, AnyError> {
    Python::with_gil(|py| -> Result<Vec<(u64, CaseKey)>, AnyError> {
        let read_wrapper = Py::new(py, PyReadFileFn { inner: read_file })
            .map_err(|e| format!("failed to build Python ReadFileFn wrapper: {}", e))?;
        let write_wrapper = Py::new(py, PyWriteFileFn { inner: write_file })
            .map_err(|e| format!("failed to build Python WriteFileFn wrapper: {}", e))?;

        let connection = engine
            .bind(py)
            .call_method0("connect")
            .map_err(|e| format!("failed to open SQLAlchemy connection: {}", e))?;

        // Run the callback stage inside its own closure so that any failure
        // is captured as a value and the connection is still closed below.
        let callback_result = (|| -> Result<Vec<(u64, CaseKey)>, AnyError> {
            let returned = startup
                .call1(py, (read_wrapper, write_wrapper, connection.clone()))
                .map_err(|e| format!("Python startup callback failed: {}", e))?;

            returned
                .bind(py)
                .try_iter()
                .map_err(|e| format!("startup return value must be iterable: {}", e))?
                .map(|entry| -> Result<(u64, CaseKey), AnyError> {
                    let entry = entry
                        .map_err(|e| format!("failed to iterate startup outputs: {}", e))?;
                    let (result_id, case_obj): (u64, Py<PyAny>) = entry
                        .extract()
                        .map_err(|e| format!("each startup output must be a tuple (int, case): {}", e))?;
                    let case_key = case_key_from_py_value(case_obj.bind(py))
                        .map_err(|e| format!("invalid case variable: {}", e))?;
                    Ok((result_id, case_key))
                })
                .collect()
        })();

        // A close failure takes precedence over whatever the callback produced.
        connection
            .call_method0("close")
            .map_err(|err| format!("failed to close SQLAlchemy connection: {}", err))?;

        callback_result
    })
}
#[cfg(feature = "python")]
#[pymodule]
fn _native(_py: Python<'_>, module: &Bound<'_, PyModule>) -> PyResult<()> {