diff --git a/pyproject.toml b/pyproject.toml index 7160052..6830bf9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "maturin" [project] name = "slingshot-microservice" -version = "0.1.3" +version = "0.1.4" description = "Opinionated Rust framework for queue-driven microservices" license = { text = "MIT" } requires-python = ">=3.8" diff --git a/slingshot_microservice/typing.py b/slingshot_microservice/typing.py index 26cd2c9..2ff2b90 100644 --- a/slingshot_microservice/typing.py +++ b/slingshot_microservice/typing.py @@ -15,7 +15,12 @@ class ReadFileFn(Protocol): class WriteFileFn(Protocol): - def __call__(self, key: str, id: int) -> BinaryIO: ... + def __call__(self, key: str, id: int) -> "WriteHandle": ... + + +class WriteHandle(Protocol): + def write(self, data: bytes | str) -> int: ... + def flush(self) -> None: ... class ProcessFn(Protocol): @@ -31,6 +36,7 @@ class ProcessFn(Protocol): __all__ = [ "CaseVariable", "ReadFileFn", + "WriteHandle", "WriteFileFn", "ProcessFn", ] diff --git a/src/lib.rs b/src/lib.rs index ea6219d..3d97764 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -802,14 +802,23 @@ struct PyWriteHandle { #[cfg(feature = "python")] #[pymethods] impl PyWriteHandle { - fn write(&self, data: &[u8]) -> PyResult { + fn write(&self, data: &Bound<'_, PyAny>) -> PyResult { + let (buffer, chars_written) = if let Ok(bytes) = data.extract::>() { + (bytes, None) + } else if let Ok(text) = data.extract::() { + let char_count = text.chars().count(); + (text.into_bytes(), Some(char_count)) + } else { + return Err(PyTypeError::new_err("write() argument must be bytes or str")); + }; + let mut file = self .file .lock() .map_err(|e| PyRuntimeError::new_err(format!("write lock poisoned: {}", e)))?; - file.write_all(data) + file.write_all(&buffer) .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - Ok(data.len()) + Ok(chars_written.unwrap_or(buffer.len())) } fn flush(&self) -> PyResult<()> {