1#[cfg(feature = "alloc")]
7use alloc::string::ToString;
8
9use super::{
10 EntropyValidator,
11 SecurityConstants,
12 TimingValidator,
13};
14use crate::api::{
15 Algorithm,
16 AlgorithmCategory,
17};
18use crate::error::Result;
19
/// Bundles the timing validator, entropy validator and security constants
/// that the `SecurityValidator` validation methods consult.
#[cfg(feature = "alloc")]
#[derive(Clone)]
pub struct SecurityValidator {
    /// Delegate for constant-time byte-slice comparison.
    timing_validator: TimingValidator,
    /// Delegate for key-entropy validation.
    entropy_validator: EntropyValidator,
    /// Size limits and expected key/ciphertext/signature/nonce sizes.
    constants: SecurityConstants,
}
32
#[cfg(feature = "alloc")]
impl SecurityValidator {
    /// Creates a new validator.
    ///
    /// Builds a fresh [`TimingValidator`] and [`EntropyValidator`], plus the
    /// constants returned by [`SecurityConstants::new`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned while constructing the timing or the
    /// entropy validator.
    pub fn new() -> Result<Self> {
        Ok(Self {
            timing_validator: TimingValidator::new()?,
            entropy_validator: EntropyValidator::new()?,
            constants: SecurityConstants::new(),
        })
    }

    /// Checks that `algorithm` supports `expected_category`.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidAlgorithm` when
    /// `algorithm.supports_category(expected_category)` is `false`.
    pub fn validate_algorithm_category(
        &self,
        algorithm: Algorithm,
        expected_category: AlgorithmCategory,
    ) -> Result<()> {
        if !algorithm.supports_category(expected_category) {
            return Err(crate::error::Error::InvalidAlgorithm {
                algorithm: "Algorithm category mismatch",
            });
        }
        Ok(())
    }

    /// Checks that `key_data` has exactly the length the security constants
    /// expect for `algorithm`.
    ///
    /// `is_secret` selects the secret-key size (`true`) or the public-key
    /// size (`false`).
    ///
    /// # Errors
    ///
    /// Propagates any error from `SecurityConstants::get_expected_key_size`.
    /// Returns `Error::InvalidKeySize` when the length differs.
    pub fn validate_key_size(
        &self,
        algorithm: Algorithm,
        key_data: &[u8],
        is_secret: bool,
    ) -> Result<()> {
        let expected_size = self.constants.get_expected_key_size(algorithm, is_secret)?;

        if key_data.len() != expected_size {
            return Err(crate::error::Error::InvalidKeySize {
                expected: expected_size,
                actual: key_data.len(),
            });
        }

        Ok(())
    }

    /// Rejects trivially bad key material.
    ///
    /// Fails on empty input, on input that is entirely `0x00` bytes, and on
    /// input that is entirely `0xFF` bytes. `validate_key_material` runs
    /// these cheap structural checks before entropy validation.
    fn ensure_non_trivial_key_bytes(key_data: &[u8]) -> Result<()> {
        if key_data.is_empty() {
            return Err(crate::error::Error::InvalidKeySize {
                expected: 1,
                actual: 0,
            });
        }

        if key_data.iter().all(|&b| b == 0) {
            return Err(crate::error::Error::InvalidKey {
                key_type: "key".to_string(),
                reason: "Key material cannot be all zeros".to_string(),
            });
        }

        if key_data.iter().all(|&b| b == 0xFF) {
            return Err(crate::error::Error::InvalidKey {
                key_type: "key".to_string(),
                reason: "Key material cannot be all ones".to_string(),
            });
        }

        Ok(())
    }

    /// Validates raw key material, independent of any algorithm's size.
    ///
    /// The material must be non-empty and not uniformly `0x00` or `0xFF`.
    /// It must then be accepted by the entropy validator's
    /// `validate_key_entropy`.
    ///
    /// # Errors
    ///
    /// Returns the first failure from the structural checks or from the
    /// entropy validator.
    pub fn validate_key_material(&self, key_data: &[u8]) -> Result<()> {
        Self::ensure_non_trivial_key_bytes(key_data)?;
        self.entropy_validator.validate_key_entropy(key_data)?;
        Ok(())
    }

    /// Validates a public key: it must have the expected public-key size for
    /// `algorithm`, then pass [`Self::validate_key_material`].
    ///
    /// # Errors
    ///
    /// Returns the first size or material error encountered.
    pub fn validate_public_key(&self, algorithm: Algorithm, key_data: &[u8]) -> Result<()> {
        self.validate_key_size(algorithm, key_data, false)?;
        self.validate_key_material(key_data)
    }

    /// Validates a secret key: it must have the expected secret-key size for
    /// `algorithm`, then pass [`Self::validate_key_material`].
    ///
    /// # Errors
    ///
    /// Returns the first size or material error encountered.
    pub fn validate_secret_key(&self, algorithm: Algorithm, key_data: &[u8]) -> Result<()> {
        self.validate_key_size(algorithm, key_data, true)?;
        // Reuse the shared material checks instead of duplicating them, so
        // that public- and secret-key validation cannot drift apart.
        self.validate_key_material(key_data)
    }

    /// Rejects AEAD messages longer than the configured maximum.
    ///
    /// The limit is inclusive: a message of exactly
    /// `max_aead_message_size()` bytes is accepted.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidMessageSize` when the message is too long.
    pub fn validate_aead_message(&self, message: &[u8]) -> Result<()> {
        let max = self.constants.max_aead_message_size();
        if message.len() > max {
            return Err(crate::error::Error::InvalidMessageSize {
                max,
                actual: message.len(),
            });
        }
        Ok(())
    }

    /// Rejects hash inputs longer than the configured maximum.
    ///
    /// The limit is inclusive: an input of exactly
    /// `max_hash_message_size()` bytes is accepted.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidMessageSize` when the input is too long.
    pub fn validate_hash_input(&self, data: &[u8]) -> Result<()> {
        let max = self.constants.max_hash_message_size();
        if data.len() > max {
            return Err(crate::error::Error::InvalidMessageSize {
                max,
                actual: data.len(),
            });
        }
        Ok(())
    }

    /// Validates a message to be signed.
    ///
    /// Signature messages are subject to the same size limit as hash inputs.
    ///
    /// # Errors
    ///
    /// See [`Self::validate_hash_input`].
    pub fn validate_signature_message(&self, message: &[u8]) -> Result<()> {
        self.validate_hash_input(message)
    }

    /// Returns the security constants used by this validator.
    pub fn security_constants(&self) -> &SecurityConstants {
        &self.constants
    }

    /// Returns mutable access to the security constants, e.g. to adjust
    /// size limits.
    pub fn security_constants_mut(&mut self) -> &mut SecurityConstants {
        &mut self.constants
    }

    /// Checks that `nonce` has exactly `standard_nonce_size()` bytes and is
    /// not all zeros.
    ///
    /// # Errors
    ///
    /// - Returns `Error::InvalidNonceSize` on a length mismatch.
    /// - Returns `Error::InvalidKey` with `key_type = "nonce"` for an all-zero
    ///   nonce.
    pub fn validate_nonce(&self, nonce: &[u8]) -> Result<()> {
        let expected = self.constants.standard_nonce_size();
        if nonce.len() != expected {
            return Err(crate::error::Error::InvalidNonceSize {
                expected,
                actual: nonce.len(),
            });
        }

        if nonce.iter().all(|&b| b == 0) {
            return Err(crate::error::Error::InvalidKey {
                key_type: "nonce".to_string(),
                reason: "Nonce cannot be all zeros".to_string(),
            });
        }

        Ok(())
    }

    /// Checks that `ciphertext` is non-empty and has exactly the length the
    /// security constants expect for `algorithm`.
    ///
    /// The emptiness check runs before the size lookup. An empty ciphertext
    /// is therefore reported as `expected: 1, actual: 0` regardless of
    /// `algorithm`.
    ///
    /// # Errors
    ///
    /// - Returns `Error::InvalidCiphertextSize` on an empty or wrongly sized
    ///   input.
    /// - Propagates any error from `get_expected_ciphertext_size`.
    pub fn validate_ciphertext(&self, algorithm: Algorithm, ciphertext: &[u8]) -> Result<()> {
        if ciphertext.is_empty() {
            return Err(crate::error::Error::InvalidCiphertextSize {
                expected: 1,
                actual: 0,
            });
        }

        let expected_size = self.constants.get_expected_ciphertext_size(algorithm)?;
        if ciphertext.len() != expected_size {
            return Err(crate::error::Error::InvalidCiphertextSize {
                expected: expected_size,
                actual: ciphertext.len(),
            });
        }

        Ok(())
    }

    /// Checks that `signature` is non-empty and has exactly the length the
    /// security constants expect for `algorithm`.
    ///
    /// The emptiness check runs before the size lookup. An empty signature
    /// is therefore reported as `expected: 1, actual: 0` regardless of
    /// `algorithm`.
    ///
    /// # Errors
    ///
    /// - Returns `Error::InvalidSignatureSize` on an empty or wrongly sized
    ///   input.
    /// - Propagates any error from `get_expected_signature_size`.
    pub fn validate_signature(&self, algorithm: Algorithm, signature: &[u8]) -> Result<()> {
        if signature.is_empty() {
            return Err(crate::error::Error::InvalidSignatureSize {
                expected: 1,
                actual: 0,
            });
        }

        let expected_size = self.constants.get_expected_signature_size(algorithm)?;
        if signature.len() != expected_size {
            return Err(crate::error::Error::InvalidSignatureSize {
                expected: expected_size,
                actual: signature.len(),
            });
        }

        Ok(())
    }

    /// Validates a caller-supplied randomness buffer.
    ///
    /// The buffer must be at least `min_randomness_size()` bytes long and
    /// must pass [`Self::validate_key_material`].
    ///
    /// # Errors
    ///
    /// - Returns `Error::InvalidKeySize` (with `expected` set to the minimum)
    ///   for a too-short buffer.
    /// - Otherwise returns any error from `validate_key_material`.
    pub fn validate_randomness(&self, randomness: &[u8]) -> Result<()> {
        let min = self.constants.min_randomness_size();
        if randomness.len() < min {
            return Err(crate::error::Error::InvalidKeySize {
                expected: min,
                actual: randomness.len(),
            });
        }

        self.validate_key_material(randomness)
    }

    /// Compares two byte slices for equality.
    ///
    /// Delegates to the timing validator's `constant_time_compare`. Any
    /// timing guarantees are those of that implementation.
    pub fn constant_time_compare(&self, a: &[u8], b: &[u8]) -> bool {
        self.timing_validator.constant_time_compare(a, b)
    }

    /// Returns the entropy validator used for key-material checks.
    pub fn entropy_validator(&self) -> &EntropyValidator {
        &self.entropy_validator
    }

    /// Returns mutable access to the entropy validator used for key-material
    /// checks.
    pub fn entropy_validator_mut(&mut self) -> &mut EntropyValidator {
        &mut self.entropy_validator
    }
}
359
#[cfg(test)]
#[cfg(feature = "alloc")]
mod tests {
    use super::*;

    /// Builds a validator for tests; construction failure is a test bug.
    fn new_validator() -> SecurityValidator {
        SecurityValidator::new().expect("SecurityValidator::new should succeed in tests")
    }

    #[test]
    fn test_security_validator_creation() {
        assert!(
            SecurityValidator::new().is_ok(),
            "SecurityValidator should be created successfully"
        );
    }

    #[test]
    fn test_validate_algorithm_category() {
        let validator = new_validator();

        assert!(
            validator
                .validate_algorithm_category(Algorithm::MlKem512, AlgorithmCategory::Kem)
                .is_ok(),
            "Should accept correct algorithm category"
        );
        assert!(
            validator
                .validate_algorithm_category(Algorithm::MlKem512, AlgorithmCategory::Signature)
                .is_err(),
            "Should reject incorrect algorithm category"
        );
    }

    #[test]
    fn test_validate_key_material() {
        let validator = new_validator();

        let valid_key: [u8; 16] = [
            0x1A, 0x2B, 0x3C, 0x4D, 0x5E, 0x6F, 0x70, 0x81, 0x92, 0xA3, 0xB4, 0xC5, 0xD6, 0xE7,
            0xF8, 0x09,
        ];
        assert!(
            validator.validate_key_material(&valid_key).is_ok(),
            "Should accept valid key material"
        );

        let zero_key = [0u8; 8];
        assert!(
            validator.validate_key_material(&zero_key).is_err(),
            "Should reject zero key"
        );

        let ones_key = [0xFFu8; 8];
        assert!(
            validator.validate_key_material(&ones_key).is_err(),
            "Should reject all-ones key"
        );

        let empty_key: [u8; 0] = [];
        assert!(
            validator.validate_key_material(&empty_key).is_err(),
            "Should reject empty key"
        );
    }

    #[test]
    fn test_validate_aead_message() {
        let validator = new_validator();

        let valid_message = [1u8; 1000];
        assert!(validator.validate_aead_message(&valid_message).is_ok());

        let oversized_message = vec![1u8; 2 * 1024 * 1024];
        assert!(validator.validate_aead_message(&oversized_message).is_err());

        // The assertion above bounds `max` below 2 MiB, so these allocations
        // stay small. The limit is inclusive.
        let max = validator.security_constants().max_aead_message_size();
        assert!(validator.validate_aead_message(&vec![1u8; max]).is_ok());
        assert!(validator.validate_aead_message(&vec![1u8; max + 1]).is_err());
    }

    #[test]
    fn test_validate_hash_input_default_unbounded() {
        let validator = new_validator();
        let large = vec![1u8; 2 * 1024 * 1024];
        assert!(validator.validate_hash_input(&large).is_ok());
    }

    #[test]
    fn test_validate_hash_input_when_capped() {
        let mut validator = new_validator();
        validator
            .security_constants_mut()
            .set_max_hash_message_size(1024);

        let oversized = [1u8; 2048];
        assert!(validator.validate_hash_input(&oversized).is_err());

        // The limit is inclusive: exactly `max` bytes pass, one more fails.
        let max = validator.security_constants().max_hash_message_size();
        assert!(validator.validate_hash_input(&vec![1u8; max]).is_ok());
        assert!(validator.validate_hash_input(&vec![1u8; max + 1]).is_err());
    }

    #[test]
    fn test_validate_signature_message_shares_hash_limit() {
        let mut validator = new_validator();
        validator
            .security_constants_mut()
            .set_max_hash_message_size(1024);

        let oversized = [1u8; 2048];
        assert!(validator.validate_signature_message(&oversized).is_err());

        let max = validator.security_constants().max_hash_message_size();
        assert!(validator.validate_signature_message(&vec![1u8; max]).is_ok());
        assert!(validator
            .validate_signature_message(&vec![1u8; max + 1])
            .is_err());
    }

    #[test]
    fn test_validate_nonce() {
        let validator = new_validator();
        let size = validator.security_constants().standard_nonce_size();

        let valid_nonce = vec![0x5Au8; size];
        assert!(
            validator.validate_nonce(&valid_nonce).is_ok(),
            "Should accept a non-zero nonce of the standard size"
        );

        let long_nonce = vec![0x5Au8; size + 1];
        assert!(
            validator.validate_nonce(&long_nonce).is_err(),
            "Should reject a nonce of the wrong size"
        );

        let zero_nonce = vec![0u8; size];
        assert!(
            validator.validate_nonce(&zero_nonce).is_err(),
            "Should reject an all-zero nonce"
        );
    }

    #[test]
    fn test_constant_time_compare() {
        let validator = new_validator();

        let a: [u8; 4] = [1, 2, 3, 4];
        let b: [u8; 4] = [1, 2, 3, 4];
        assert!(
            validator.constant_time_compare(&a, &b),
            "Should return true for equal slices"
        );

        let c: [u8; 4] = [1, 2, 3, 5];
        assert!(
            !validator.constant_time_compare(&a, &c),
            "Should return false for different slices"
        );

        let d: [u8; 3] = [1, 2, 3];
        assert!(
            !validator.constant_time_compare(&a, &d),
            "Should return false for different length slices"
        );
    }
}