Add test for intcox
This commit is contained in:
		
							parent
							
								
									46e3b189ce
								
							
						
					
					
						commit
						6ac2d9f055
					
				| @ -12,6 +12,9 @@ rayon = "1.7.0" | ||||
| serde = { version = "1.0.160", features = ["derive"] } | ||||
| serde_json = "1.0.96" | ||||
| 
 | ||||
| [profile.test] | ||||
| opt-level = 3 | ||||
| 
 | ||||
| [profile.perf] | ||||
| inherits = "release" | ||||
| debug = true | ||||
|  | ||||
| @ -162,7 +162,7 @@ impl IntervalCensoredCoxData { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| fn fit_interval_censored_cox(data_times: DMatrix<f64>, mut data_indep: DMatrix<f64>, max_iterations: u32, tolerance: f64, reduced: bool, progress_bar: ProgressBar) -> IntervalCensoredCoxResult { | ||||
| pub fn fit_interval_censored_cox(data_times: DMatrix<f64>, mut data_indep: DMatrix<f64>, max_iterations: u32, tolerance: f64, reduced: bool, progress_bar: ProgressBar) -> IntervalCensoredCoxResult { | ||||
| 	// ----------------------
 | ||||
| 	// Prepare for regression
 | ||||
| 	
 | ||||
| @ -572,11 +572,11 @@ fn profile_log_likelihood_obs(data: &IntervalCensoredCoxData, beta: DVector<f64> | ||||
| } | ||||
| 
 | ||||
| #[derive(Serialize, Deserialize)] | ||||
| struct IntervalCensoredCoxResult { | ||||
| 	params: Vec<f64>, | ||||
| 	params_se: Vec<f64>, | ||||
| 	ll_model: f64, | ||||
| 	ll_null: f64, | ||||
| pub struct IntervalCensoredCoxResult { | ||||
| 	pub params: Vec<f64>, | ||||
| 	pub params_se: Vec<f64>, | ||||
| 	pub ll_model: f64, | ||||
| 	pub ll_null: f64, | ||||
| 	// TODO: cumulative hazard, etc.
 | ||||
| } | ||||
| 
 | ||||
|  | ||||
							
								
								
									
										1
									
								
								src/lib.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								src/lib.rs
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1 @@ | ||||
| pub mod intcox; | ||||
| @ -16,7 +16,7 @@ | ||||
| 
 | ||||
| use clap::{Parser, Subcommand}; | ||||
| 
 | ||||
| mod intcox; | ||||
| use hpstat::intcox; | ||||
| 
 | ||||
| #[derive(Parser)] | ||||
| #[command(about="High-performance statistics implementations")] | ||||
|  | ||||
							
								
								
									
										87
									
								
								tests/intcox.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										87
									
								
								tests/intcox.rs
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,87 @@ | ||||
| // hpstat: High-performance statistics implementations
 | ||||
| // Copyright © 2023  Lee Yingtong Li (RunasSudo)
 | ||||
| // 
 | ||||
| // This program is free software: you can redistribute it and/or modify
 | ||||
| // it under the terms of the GNU Affero General Public License as published by
 | ||||
| // the Free Software Foundation, either version 3 of the License, or
 | ||||
| // (at your option) any later version.
 | ||||
| // 
 | ||||
| // This program is distributed in the hope that it will be useful,
 | ||||
| // but WITHOUT ANY WARRANTY; without even the implied warranty of
 | ||||
| // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 | ||||
| // GNU Affero General Public License for more details.
 | ||||
| // 
 | ||||
| // You should have received a copy of the GNU Affero General Public License
 | ||||
| // along with this program.  If not, see <https://www.gnu.org/licenses/>.
 | ||||
| 
 | ||||
| use std::fs; | ||||
| 
 | ||||
| use indicatif::ProgressBar; | ||||
| use nalgebra::DMatrix; | ||||
| 
 | ||||
| use hpstat::intcox::fit_interval_censored_cox; | ||||
| 
 | ||||
| #[test] | ||||
| fn test_intcox_zeng_mao_lin() { | ||||
| 	// Compare "Bangkok Metropolitan Administration HIV" data from Zeng, Mao & Lin (2016) with IntCens 0.2 output
 | ||||
| 	
 | ||||
| 	let contents = fs::read_to_string("tests/zeng_mao_lin.csv").unwrap(); | ||||
| 	let lines: Vec<String> = contents.trim_end().split("\n").map(|s| s.to_string()).collect(); | ||||
| 	
 | ||||
| 	// Read data into matrices
 | ||||
| 	
 | ||||
| 	let mut data_times: DMatrix<f64> = DMatrix::zeros( | ||||
| 		2,  // Left time, right time
 | ||||
| 		lines.len() - 1  // Minus 1 row for header row
 | ||||
| 	); | ||||
| 	
 | ||||
| 	// Called "Z" in the paper and "X" in the C++ code
 | ||||
| 	let mut data_indep: DMatrix<f64> = DMatrix::zeros( | ||||
| 		lines[0].split(",").count() - 2, | ||||
| 		lines.len() - 1  // Minus 1 row for header row
 | ||||
| 	); | ||||
| 	
 | ||||
| 	// Read data
 | ||||
| 	// FIXME: Parse CSV more robustly
 | ||||
| 	for (i, row) in lines.iter().skip(1).enumerate() { | ||||
| 		for (j, item) in row.split(",").enumerate() { | ||||
| 			let value = match item { | ||||
| 				"inf" => f64::INFINITY, | ||||
| 				_ => item.parse().expect("Malformed float") | ||||
| 			}; | ||||
| 			
 | ||||
| 			if j < 2 { | ||||
| 				data_times[(j, i)] = value; | ||||
| 			} else { | ||||
| 				data_indep[(j - 2, i)] = value; | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	
 | ||||
| 	// Fit regression
 | ||||
| 	let progress_bar = ProgressBar::hidden(); | ||||
| 	//let result = fit_interval_censored_cox(data_times, data_indep, 200, 0.00005, false, progress_bar);
 | ||||
| 	let result = fit_interval_censored_cox(data_times, data_indep, 100, 0.0001, false, progress_bar); | ||||
| 	
 | ||||
| 	// ./unireg --in zeng_mao_lin.csv --out out.txt --r 0.0 --model "(Left_Time, Right_Time) = Needle + Needle2 + LogAge + GenderM + RaceO + RaceW + GenderM_RaceO + GenderM_RaceW" --sep , --inf_char inf --convergence_threshold 0.002
 | ||||
| 	
 | ||||
| 	assert!((result.ll_model - -603.205).abs() < 1.0); | ||||
| 	
 | ||||
| 	assert!((result.params[0] - -0.18636961816695094).abs() < 0.01); | ||||
| 	assert!((result.params[1] - 0.080478699024478656).abs() < 0.01); | ||||
| 	assert!((result.params[2] - -0.71260450817296639).abs() < 0.01); | ||||
| 	assert!((result.params[3] - -0.22937443803422858).abs() < 0.01); | ||||
| 	assert!((result.params[4] - -0.14101449484871434).abs() < 0.01); | ||||
| 	assert!((result.params[5] - -0.43894526362102332).abs() < 0.01); | ||||
| 	assert!((result.params[6] - 0.064533885082884768).abs() < 0.01); | ||||
| 	assert!((result.params[7] - 0.20970425315378016).abs() < 0.01); | ||||
| 	
 | ||||
| 	assert!((result.params_se[0] - 0.41496954829036448).abs() < 0.01); | ||||
| 	assert!((result.params_se[1] - 0.15086156546712554).abs() < 0.01); | ||||
| 	assert!((result.params_se[2] - 0.36522062865858951).abs() < 0.01); | ||||
| 	assert!((result.params_se[3] - 0.32195496906604004).abs() < 0.01); | ||||
| 	assert!((result.params_se[4] - 0.3912241733944129).abs() < 0.01); | ||||
| 	assert!((result.params_se[5] - 0.41907763222198746).abs() < 0.01); | ||||
| 	assert!((result.params_se[6] - 0.45849947730170948).abs() < 0.01); | ||||
| 	assert!((result.params_se[7] - 0.48803508171247434).abs() < 0.01); | ||||
| } | ||||
							
								
								
									
										1125
									
								
								tests/zeng_mao_lin.csv
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1125
									
								
								tests/zeng_mao_lin.csv
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user