use core::cmp::{max, min};
#[cfg(feature = "std")]
use std::sync::Arc;
use crate::channel::{self, ChannelError};
use crate::error::{Error, ErrorCode};
use crate::internal::fuse_io;
use crate::internal::fuse_kernel;
use crate::internal::types::ProtocolVersion;
use crate::protocol::common::{RequestHeader, UnknownRequest};
/// A [`channel::Channel`] that can produce another `Self` handle via
/// [`ServerChannel::try_clone`].
pub trait ServerChannel: channel::Channel {
    /// Attempts to create another `Self` handle, reporting failure with the
    /// channel's own error type.
    ///
    /// NOTE(review): presumably the new handle refers to the same underlying
    /// channel (e.g. for additional worker threads) — confirm with implementors.
    fn try_clone(&self) -> Result<Self, Self::Error>
    where
        Self: Sized;
}
/// Context for a single request, holding the raw `fuse_in_header` it arrived
/// with.
pub struct ServerContext {
    header: fuse_kernel::fuse_in_header,
}

// The impl previously declared an unused lifetime (`impl<'a> ServerContext`);
// nothing here borrows, so no lifetime parameter is needed.
impl ServerContext {
    /// Wraps the raw kernel header of the request being served.
    pub(crate) fn new(header: fuse_kernel::fuse_in_header) -> Self {
        Self { header }
    }

    /// Returns a typed [`RequestHeader`] view of the stored raw header.
    pub fn request_header(&self) -> &RequestHeader {
        RequestHeader::new_ref(&self.header)
    }
}
/// Optional callbacks for observing how requests are handled.
///
/// Every method has an empty default body, so an implementation only needs
/// to override the events it cares about.
// The default bodies ignore their arguments; the parameters stay named so the
// trait's documentation shows what each hook receives.
#[allow(unused_variables)]
pub trait ServerHooks {
    /// Receives the header of a request.
    ///
    /// NOTE(review): presumably invoked for every received request; the call
    /// site is outside this section — confirm.
    fn request(&self, request_header: &RequestHeader) {}

    /// Receives a request wrapped as an [`UnknownRequest`].
    ///
    /// NOTE(review): presumably for opcodes the server does not recognize —
    /// confirm against the call site.
    fn unknown_request(&self, request: &UnknownRequest) {}

    /// Receives the header of a request that is about to be answered with
    /// `ENOSYS` because no handler implemented it.
    fn unhandled_request(&self, request_header: &RequestHeader) {}

    /// Receives a request header together with an [`Error`].
    ///
    /// NOTE(review): call site not visible here; presumably reports request
    /// decoding failures — confirm.
    fn request_error(&self, request_header: &RequestHeader, err: Error) {}

    /// Receives the error code when encoding or sending a successful response
    /// failed; the server then tries to reply with `EIO` instead.
    fn response_error(&self, request_header: &RequestHeader, code: Option<ErrorCode>) {}

    /// Receives the error code when an asynchronous (`RespondAsync`) error
    /// reply could not be sent.
    fn async_channel_error(&self, request_header: &RequestHeader, code: Option<ErrorCode>) {}
}
// Only referenced from std-gated code paths, hence the dead_code allowance
// when `std` is off (NOTE(review): inferred from the cfg_attr — confirm).
#[cfg_attr(not(feature = "std"), allow(dead_code))]
/// A [`ServerHooks`] type that keeps every default (empty) hook method.
///
/// As an empty enum it has no values: it can name the `Hooks` type parameter
/// but can never be instantiated. NOTE(review): presumably paired with
/// `hooks: None` when no hooks are supplied — confirm at call sites.
pub enum NoopServerHooks {}
impl ServerHooks for NoopServerHooks {}
/// Extra bytes added on top of `max_write` when sizing the read buffer.
///
/// NOTE(review): presumably room for the request header and the fixed-size
/// request fields that precede write payloads — confirm.
const HEADER_OVERHEAD: usize = 4096;

/// Returns the read-buffer size needed to receive requests carrying up to
/// `max_write` payload bytes.
///
/// The result is `HEADER_OVERHEAD + max_write`, but never less than
/// `fuse_kernel::FUSE_MIN_READ_BUFFER`. The addition saturates, so a huge
/// `max_write` on a 32-bit target yields `usize::MAX` instead of overflowing,
/// which would panic in debug builds or wrap to a tiny size in release builds.
#[cfg_attr(not(feature = "std"), allow(dead_code))]
pub(crate) fn read_buf_size(max_write: u32) -> usize {
    let max_write = max_write as usize;
    max(
        HEADER_OVERHEAD.saturating_add(max_write),
        fuse_kernel::FUSE_MIN_READ_BUFFER,
    )
}
/// Largest `max_write` for which `read_buf_size` returns exactly
/// `fuse_kernel::FUSE_MIN_READ_BUFFER` (the minimum buffer minus
/// `HEADER_OVERHEAD`).
///
/// NOTE(review): presumably used to cap writes when no_std builds use a
/// minimum-sized read buffer — confirm at call sites.
#[cfg(not(feature = "std"))]
pub(crate) const fn capped_max_write() -> u32 {
    (fuse_kernel::FUSE_MIN_READ_BUFFER - HEADER_OVERHEAD) as u32
}
/// Chooses the protocol version to use with the kernel.
///
/// Returns `None` when the kernel's major version differs from
/// `fuse_kernel::FUSE_KERNEL_VERSION`. Otherwise returns that major version
/// paired with the smaller of the kernel's minor version and
/// `fuse_kernel::FUSE_KERNEL_MINOR_VERSION`.
pub(crate) fn negotiate_version(
    kernel: ProtocolVersion,
) -> Option<ProtocolVersion> {
    let supported_major = fuse_kernel::FUSE_KERNEL_VERSION;
    if kernel.major() == supported_major {
        let agreed_minor =
            min(kernel.minor(), fuse_kernel::FUSE_KERNEL_MINOR_VERSION);
        Some(ProtocolVersion::new(supported_major, agreed_minor))
    } else {
        None
    }
}
/// Repeatedly receives a request from `channel` into `read_buf`, decodes it,
/// and hands the decoder to `cb`.
///
/// Under `fuse_io::Semantics::FUSE`, a receive failure with `ENODEV` ends the
/// loop with `Ok(())`. NOTE(review): presumably this signals the filesystem
/// being unmounted — confirm. Any other receive error, any decode error, and
/// any error returned by `cb` stops the loop and is returned.
pub(crate) fn main_loop<Buf, C, Cb>(
    channel: &C,
    read_buf: &mut Buf,
    fuse_version: ProtocolVersion,
    semantics: fuse_io::Semantics,
    cb: Cb,
) -> Result<(), C::Error>
where
    Buf: fuse_io::AlignedBuffer,
    C: channel::Channel,
    Cb: Fn(fuse_io::RequestDecoder) -> Result<(), C::Error>,
{
    loop {
        let received = match channel.receive(read_buf.get_mut()) {
            Ok(size) => size,
            Err(err)
                if semantics == fuse_io::Semantics::FUSE
                    && err.error_code() == Some(ErrorCode::ENODEV) =>
            {
                return Ok(());
            }
            Err(err) => return Err(err),
        };
        let decoder = fuse_io::RequestDecoder::new(
            fuse_io::aligned_slice(read_buf, received),
            fuse_version,
            semantics,
        )?;
        cb(decoder)?;
    }
}
/// Channel bound that depends on the build.
///
/// With the `std` feature, `T` must be `channel::Channel + Send + Sync +
/// 'static`; without it, only `channel::Channel` is required. The blanket
/// impls below set `T` to the implementing type itself.
pub(crate) trait MaybeSendChannel {
    #[cfg(feature = "std")]
    type T: channel::Channel + Send + Sync + 'static;
    #[cfg(not(feature = "std"))]
    type T: channel::Channel;
}

#[cfg(feature = "std")]
impl<C: channel::Channel + Send + Sync + 'static> MaybeSendChannel for C {
    type T = C;
}

#[cfg(not(feature = "std"))]
impl<C: channel::Channel> MaybeSendChannel for C {
    type T = C;
}
/// Hooks bound that depends on the build.
///
/// With the `std` feature, `T` must be `ServerHooks + Send + Sync + 'static`;
/// without it, only `ServerHooks` is required. The blanket impls below set
/// `T` to the implementing type itself.
pub(crate) trait MaybeSendHooks {
    #[cfg(feature = "std")]
    type T: ServerHooks + Send + Sync + 'static;
    #[cfg(not(feature = "std"))]
    type T: ServerHooks;
}

#[cfg(feature = "std")]
impl<H: ServerHooks + Send + Sync + 'static> MaybeSendHooks for H {
    type T = H;
}

#[cfg(not(feature = "std"))]
impl<H: ServerHooks> MaybeSendHooks for H {
    type T = H;
}
// Traits are `pub` but live in a non-`pub` module. Code outside this file
// can name `Respond` (which uses `private::Respond` as a supertrait) but
// cannot implement it.
mod private {
    /// Supertrait of the public `Respond`; names the helper type that
    /// carries implementation-only operations for a responder.
    pub trait Respond {
        type Internal: RespondInternal<Self>;
    }

    /// Implementation-only operations on a responder of type `R`.
    pub trait RespondInternal<R: ?Sized> {
        /// Called before a request is rejected as unhandled.
        fn unhandled_request(respond: &R);
    }
}
pub(crate) fn unhandled_request<T, R: Respond<T>>(respond: R) {
use private::RespondInternal;
R::Internal::unhandled_request(&respond);
respond.err(ErrorCode::ENOSYS)
}
/// Sends the reply for a single request, where `R` is the response type.
///
/// The `private::Respond` supertrait lives in a private module, so only types
/// defined in this crate can implement `Respond`.
pub trait Respond<R>: private::Respond {
    /// Replies with `response`, consuming the responder.
    fn ok(self, response: &R);
    /// Replies with the error `err`, consuming the responder.
    fn err(self, err: ErrorCode);
    /// Converts the responder into an owned [`RespondAsync`] that can send
    /// the reply later.
    #[cfg(feature = "std")]
    #[cfg_attr(doc, doc(cfg(feature = "std")))]
    fn into_async(self) -> RespondAsync<R>;
}
/// Borrowed responder for one request, answering through `channel`.
pub(crate) struct RespondRef<'a, C: channel::Channel, Hooks> {
    /// Channel the reply is written to.
    channel: &'a C,
    /// Hooks notified about unhandled requests and response failures.
    hooks: Option<&'a Hooks>,
    /// Receives the outcome of sending an error reply (`err_impl`).
    channel_err: &'a mut Result<(), C::Error>,
    /// Header of the request being answered; its request id tags the reply.
    header: &'a RequestHeader,
    /// Protocol version handed to the response encoder.
    fuse_version: ProtocolVersion,
    /// Shared handle to `channel`, cloned when converting to `RespondAsync`.
    #[cfg(feature = "std")]
    channel_arc: &'a Arc<C>,
    /// Shared handle to `hooks`, cloned when converting to `RespondAsync`.
    #[cfg(feature = "std")]
    hooks_arc: Option<&'a Arc<Hooks>>,
}
impl<'a, C, Hooks> RespondRef<'a, C, Hooks>
where
C: channel::Channel,
Hooks: ServerHooks,
{
pub(crate) fn new(
channel: &'a C,
hooks: Option<&'a Hooks>,
channel_err: &'a mut Result<(), C::Error>,
header: &'a RequestHeader,
fuse_version: ProtocolVersion,
#[cfg(feature = "std")] channel_arc: &'a Arc<C>,
#[cfg(feature = "std")] hooks_arc: Option<&'a Arc<Hooks>>,
) -> Self {
Self {
channel,
hooks,
channel_err,
header,
fuse_version,
#[cfg(feature = "std")]
channel_arc,
#[cfg(feature = "std")]
hooks_arc,
}
}
pub(crate) fn encoder(&self) -> fuse_io::ResponseEncoder<C> {
fuse_io::ResponseEncoder::new(
self.channel,
self.header.request_id(),
self.fuse_version,
)
}
fn ok_impl<R>(self, response: &R)
where
R: fuse_io::EncodeResponse,
{
if let Err(err) = response.encode_response(self.encoder()) {
if let Some(hooks) = &self.hooks {
hooks.response_error(self.header, err.error_code())
}
self.err_impl(ErrorCode::EIO);
}
}
pub(crate) fn err_impl(self, err: ErrorCode) {
*self.channel_err = self.encoder().encode_error(err);
}
}
// Connects `RespondRef` to its implementation-only helper type.
impl<C: channel::Channel, Hooks: ServerHooks> private::Respond
    for RespondRef<'_, C, Hooks>
{
    type Internal = RespondRefInternal;
}
/// Helper type carrying `RespondRef`'s implementation-only operations.
/// The private `()` field prevents construction outside this module.
pub struct RespondRefInternal(());

impl<C: channel::Channel, Hooks: ServerHooks>
    private::RespondInternal<RespondRef<'_, C, Hooks>> for RespondRefInternal
{
    /// Forwards the request header to `ServerHooks::unhandled_request`, if
    /// hooks are configured.
    fn unhandled_request(respond: &RespondRef<'_, C, Hooks>) {
        let header = respond.header;
        if let Some(hooks) = respond.hooks {
            hooks.unhandled_request(header);
        }
    }
}
// With `std`, responders can also become `RespondAsync`. That requires
// thread-safe, `'static` channel and hook types.
#[cfg(feature = "std")]
impl<C, Hooks, R> Respond<R> for RespondRef<'_, C, Hooks>
where
    C: channel::Channel + Send + Sync + 'static,
    Hooks: ServerHooks + Send + Sync + 'static,
    R: fuse_io::EncodeResponse,
{
    fn ok(self, response: &R) {
        Self::ok_impl(self, response);
    }

    fn err(self, err: ErrorCode) {
        Self::err_impl(self, err);
    }

    fn into_async(self) -> RespondAsync<R> {
        Self::new_respond_async(self)
    }
}
// Without `std` there is no async conversion, so no `Send`/`Sync`/`'static`
// bounds are needed.
#[cfg(not(feature = "std"))]
impl<C, Hooks, R> Respond<R> for RespondRef<'_, C, Hooks>
where
    C: channel::Channel,
    Hooks: ServerHooks,
    R: fuse_io::EncodeResponse,
{
    fn ok(self, response: &R) {
        Self::ok_impl(self, response);
    }

    fn err(self, err: ErrorCode) {
        Self::err_impl(self, err);
    }
}
#[cfg(feature = "std")]
#[cfg_attr(doc, doc(cfg(feature = "std")))]
/// Owned responder produced by [`Respond::into_async`].
///
/// It boxes an implementation holding its own `Arc` handles, so it is
/// `'static`. Because `RespondAsyncInner` has `Send + Sync` supertraits, the
/// box, and therefore this type, is `Send + Sync`.
pub struct RespondAsync<R>(Box<dyn RespondAsyncInner<R> + 'static>);
#[cfg(feature = "std")]
impl<R> RespondAsync<R> {
    /// Sends `response` as the successful reply, consuming the responder.
    pub fn ok(self, response: &R) {
        let RespondAsync(inner) = self;
        inner.ok(response);
    }

    /// Sends the error `err` as the reply, consuming the responder.
    pub fn err(self, err: ErrorCode) {
        let RespondAsync(inner) = self;
        inner.err(err);
    }
}
/// Object-safe backend of `RespondAsync`. The `Send + Sync` supertraits make
/// the boxed trait object shareable across threads.
#[cfg(feature = "std")]
trait RespondAsyncInner<R>: Send + Sync {
    /// Sends `reply` as the successful response.
    fn ok(&self, reply: &R);
    /// Sends the error `code` as the response.
    fn err(&self, code: ErrorCode);
}
#[cfg(feature = "std")]
/// Owned state behind a `RespondAsync`, built by `new_respond_async`.
///
/// Fields are dropped in declaration order (channel before hooks), so keep
/// this order if either type's `Drop` is observable.
struct RespondAsyncInnerImpl<C, Hooks> {
    // Shared channel handle the reply is written to.
    channel: Arc<C>,
    // Shared hooks, if any, notified about send/encode failures.
    hooks: Option<Arc<Hooks>>,
    // Owned copy of the request header; its request id tags the reply.
    header: RequestHeader,
    // Protocol version handed to the response encoder.
    fuse_version: ProtocolVersion,
}
#[cfg(feature = "std")]
impl<C: channel::Channel, Hooks: ServerHooks> RespondAsyncInnerImpl<C, Hooks> {
    /// Builds a response encoder addressed to the stored request's id.
    fn encoder(&self) -> fuse_io::ResponseEncoder<C> {
        let channel: &C = &self.channel;
        fuse_io::ResponseEncoder::new(channel, self.header.request_id(), self.fuse_version)
    }

    /// Sends an error reply. Nobody waits on an async responder, so a send
    /// failure is only reported via `ServerHooks::async_channel_error`.
    fn err_impl(&self, err: ErrorCode) {
        let send_err = match self.encoder().encode_error(err) {
            Ok(()) => return,
            Err(e) => e,
        };
        if let Some(hooks) = self.hooks.as_ref() {
            hooks.async_channel_error(&self.header, send_err.error_code());
        }
    }
}
#[cfg(feature = "std")]
impl<C, Hooks, R> RespondAsyncInner<R> for RespondAsyncInnerImpl<C, Hooks>
where
    C: channel::Channel + Send + Sync,
    Hooks: ServerHooks + Send + Sync,
    R: fuse_io::EncodeResponse,
{
    /// Sends `response`. If encoding or sending fails, reports it through
    /// `ServerHooks::response_error` and falls back to an `EIO` reply.
    fn ok(&self, response: &R) {
        let encode_err = match response.encode_response(self.encoder()) {
            Ok(_) => return,
            Err(e) => e,
        };
        if let Some(hooks) = self.hooks.as_ref() {
            hooks.response_error(&self.header, encode_err.error_code());
        }
        self.err_impl(ErrorCode::EIO);
    }

    fn err(&self, err: ErrorCode) {
        self.err_impl(err);
    }
}
#[cfg(feature = "std")]
impl<'a, C, Hooks> RespondRef<'a, C, Hooks>
where
    C: channel::Channel + Send + Sync + 'static,
    Hooks: ServerHooks + Send + Sync + 'static,
{
    /// Converts this borrowed responder into an owned [`RespondAsync`].
    ///
    /// The new responder clones the shared channel/hooks handles (refcount
    /// bumps only) and copies the request header, so it no longer borrows
    /// from the request loop.
    fn new_respond_async<R>(self) -> RespondAsync<R>
    where
        R: fuse_io::EncodeResponse,
    {
        RespondAsync(Box::new(RespondAsyncInnerImpl {
            channel: Arc::clone(self.channel_arc),
            // `Option<&Arc<_>>::cloned` instead of `.map(|h| h.clone())`.
            hooks: self.hooks_arc.cloned(),
            header: self.header.clone(),
            fuse_version: self.fuse_version,
        }))
    }
}