Echo Writes Code

configuration.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
use crate::errors::{ Result, Error };

use clap::{ Parser };
use config::{ Config };
use tracing_subscriber::layer::{ SubscriberExt };
use tracing_subscriber::util::{ SubscriberInitExt };

use std::path::{ PathBuf };
use std::str::{ FromStr };

macro_rules! version_with_hash {
	() => {
		format!("{} (hash {})", env!("CARGO_PKG_VERSION"), env!("GIT_HASH"))
	}
}

#[derive(Parser)]
#[command(author, version = version_with_hash!(), about)]
struct Arguments {
	/// Additional configuration files to load (may be passed multiple times)
	#[arg(id="configuration_file", short, long)]
	configuration_files: Vec<String>,

	/// Additional configuration parameters to use (may be passed multiple times; overrides configuration files)
	#[arg(id="parameter", short, long)]
	configuration_parameters: Vec<ConfigurationParameter>,
}

#[derive(Clone)]
struct ConfigurationParameter(String, String);

impl FromStr for ConfigurationParameter {
	type Err = clap::Error;

	fn from_str(s: &str) -> ::std::result::Result<ConfigurationParameter, clap::Error> {
		let parts: Vec<&str> = s.splitn(2, '=').collect();

		if parts.len() != 2 {
			return Err(clap::Error::new(clap::error::ErrorKind::WrongNumberOfValues));
		}

		Ok(ConfigurationParameter(parts[0].to_string(), parts[1].to_string()))
	}
}

pub(crate) struct Configuration {
	pub(crate) content_metadata_file: String,
	pub(crate) content_root: PathBuf,
	pub(crate) email_from: String,
	pub(crate) email_live: bool,
	pub(crate) email_smtp_host: String,
	pub(crate) email_smtp_port: u16,
	pub(crate) network_host: String,
	pub(crate) network_port: u16,
}

pub(crate) fn configure() -> Result<Configuration> {
	// Configure tracing (aka logging)
	tracing_subscriber::registry()
		.with(tracing_subscriber::fmt::layer())
		.init();

	// Parse command line arguments
	let arguments = Arguments::parse();

	// Set up default configuration values
	let mut builder = Config::builder()
		.set_default("content.metadata_file", "metadata.json")?
		.set_default("email.from", "web-pylon@example.com")?
		.set_default("email.live", false)?
		.set_default("email.smtp_host", "localhost")?
		.set_default("email.smtp_port", 25)?
		.set_default("network.host", "localhost")?
		.set_default("network.port", 8080)?;

	// Prepare the list of configuration files to read
	if !arguments.configuration_files.is_empty() {
		for path in arguments.configuration_files {
			tracing::debug!("Found configuration file: {}", path);
			builder = builder.add_source(config::File::with_name(&path));
		}
	} else {
		tracing::warn!("No configuration files detected; using default settings");
	}

	// Apply configuration parameters specified on the command line
	for parameter in arguments.configuration_parameters {
		tracing::debug!("Overriding parameter {}: {}", parameter.0, parameter.1);
		builder = builder.set_override(parameter.0, parameter.1)?;
	}

	// Actually read the configuration files
	let config = builder.build()?;

	// Perform any per-parameter validation
	let email_smtp_port =
		match config.get_int("email.smtp_port")?.try_into() {
			Ok(port) => Ok(port),
			Err(e) => Err(Error::InvalidPort(e)),
		}?;

	let network_port =
		match config.get_int("network.port")?.try_into() {
			Ok(port) => Ok(port),
			Err(e) => Err(Error::InvalidPort(e)),
		}?;

	Ok(Configuration {
		content_metadata_file: config.get_string("content.metadata_file")?,
		content_root: config.get_string("content.root")?.into(),
		email_from: config.get_string("email.from")?,
		email_live: config.get_bool("email.live")?,
		email_smtp_host: config.get_string("email.smtp_host")?,
		email_smtp_port,
		network_host: config.get_string("network.host")?,
		network_port,
	})
}