aboutsummaryrefslogtreecommitdiff
path: root/src/main.rs
blob: f284b754616155b1fb9dd64e6780e168e0bc2069 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
use std::collections::HashMap;
use std::convert::Infallible;
use std::fs::read_to_string;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
use std::path::PathBuf;
use std::sync::Arc;

use hyper::service::{make_service_fn, service_fn};
use hyper::{header, Body, Request, Response, Server, StatusCode};
use serde::Deserialize;
use structopt::StructOpt;

mod utils;
use utils::HttpHandler;

// Command-line options.
//
// NOTE: plain `//` comments are used on purpose here. structopt turns `///` doc comments
// into `--help` text, so adding them would change the program's output.
#[derive(Debug, StructOpt)]
struct Opt {
    // `--config-file <path>` (structopt derives the kebab-case long flag from the field name).
    // Path of the TOML configuration to load; defaults to `narchttpd.toml`, which as a
    // relative path is resolved against the current working directory.
    #[structopt(long, parse(from_os_str), default_value = "narchttpd.toml")]
    config_file: PathBuf,
}

/// How a single domain is served.
///
/// The variant is chosen by the domain table's `mode` key, whose value is the kebab-cased
/// variant name: `mode = "static"` or `mode = "proxy-child"`.
///
/// `rename_all` on an enum renames only the variants, so the fields below keep their
/// snake_case spelling in the config file (e.g. `in_dir`, not `in-dir`).
#[derive(Deserialize)]
#[serde(tag = "mode", rename_all = "kebab-case")]
enum DomainConfig {
    /// Requests are handled by `utils::serve_static::Params`.
    Static {
        /// Passed to `serve_static::Params::new`; presumably the directory whose files are
        /// served — TODO confirm against `utils::serve_static`.
        root: PathBuf,
    },

    /// Requests are handled by `utils::proxy_child::ProxyChild`.
    ProxyChild {
        /// Passed to `ProxyChild::new`; presumably the command the child process is started
        /// with — TODO confirm against `utils::proxy_child`.
        command: String,
        /// Optional: deserializes to `None` when the key is absent. Presumably the child's
        /// working directory — TODO confirm.
        in_dir: Option<PathBuf>,
        /// Passed to `ProxyChild::new`; presumably the local port the child listens on and
        /// requests are forwarded to — TODO confirm.
        port: u16,
    },
}

impl DomainConfig {
    fn handler(self) -> Box<dyn HttpHandler> {
        match self {
            Self::Static { root } => Box::new(utils::serve_static::Params::new(root)),
            Self::ProxyChild {
                command,
                in_dir,
                port,
            } => Box::new(utils::proxy_child::ProxyChild::new(command, in_dir, port)),
        }
    }
}

/// Top-level layout of the TOML config file.
///
/// Apart from the two port lists, every top-level key is a domain name mapped to that
/// domain's `DomainConfig` table, for example:
///
/// ```toml
/// http_ports = [8080]
///
/// ["example.com"]
/// mode = "static"
/// root = "/srv/www"
/// ```
#[derive(Deserialize)]
struct Config {
    /// Ports to serve plain HTTP on.
    http_ports: Vec<u16>,
    /// Ports to serve HTTPS on. Optional and empty by default: HTTPS is not implemented,
    /// and startup rejects a non-empty list, so forcing users to spell out
    /// `https_ports = []` only added friction.
    #[serde(default)]
    https_ports: Vec<u16>,
    /// Every remaining top-level table, keyed by domain name.
    #[serde(flatten)]
    domains: HashMap<String, DomainConfig>,
}

#[tokio::main]
async fn main() {
    let opt = Opt::from_args();

    let config_data = read_to_string(&opt.config_file).expect("Config file not found");

    let Config {
        http_ports,
        https_ports,
        domains,
    } = toml::from_str(&config_data).expect("Config file not valid");

    assert!(https_ports.is_empty(), "HTTPS is complicated oops");

    let domains: HashMap<_, _> = domains
        .into_iter()
        .map(|(domain, config)| (domain, Arc::new(config.handler())))
        .collect();
    let do_response_domains = domains.clone();

    // TODO learn hyper
    let do_response = move |req: Request<Body>| async move {
        for (domain, handler) in do_response_domains.clone() {
            let req_domain = req.headers().get(header::HOST).unwrap();
            if &domain == req_domain {
                eprintln!("request {:?} matched domain {}", req, domain);
                // matched!
                return handler.handle(req).await;
            }
        }
        Response::builder()
            .status(StatusCode::NOT_FOUND)
            .body(Body::from("Not Found"))
            .unwrap()
    };
    let addr = http_ports.into_iter().flat_map(|port| {
        ["127.0.0.1", "::1"]
            .iter()
            .map(move |ip| SocketAddr::new(ip.parse().unwrap(), port))
    });

    let make_svc = make_service_fn(move |_conn| async move {
        Ok::<_, Infallible>(service_fn(move |req| async move {
            let response = do_response(req).await;
            Ok::<_, Infallible>(response)
        }))
    });

    // TODO uhh
    let mut addr = addr;
    let addr = addr.nth(0).unwrap();
    let server = Server::bind(&addr).serve(make_svc);

    if let Err(e) = server.await {
        eprintln!("server error: {}", e);
    }
}