Mat2Java.zip
立即下载
资源介绍:
Mat2Java.zip
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.linear.*;
import org.apache.poi.ss.usermodel.*;
import org.apache.poi.xssf.usermodel.XSSFWorkbook;
import org.apache.commons.math3.special.Erf;
import org.knowm.xchart.SwingWrapper;
import org.knowm.xchart.XYChart;
import org.knowm.xchart.XYChartBuilder;
import org.knowm.xchart.style.lines.SeriesLines;
import org.knowm.xchart.style.markers.SeriesMarkers;
import java.awt.Color;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
//我在项目依赖中配置了Apache POI库(5.2.5版本)、Log4j(2.20版本)、Apache Commons Math库、XChart库等,详见pom.xml。
public class Analysis {
public static void main(String[] args) {
try {
// 1. 读取Excel文件
FileInputStream file = new FileInputStream("src/main/resources/data2.xlsx");
//为了方便读取,把原试验数据中需要的刺激量、响应量单独提取作为新表data2.xlsx
Workbook workbook = new XSSFWorkbook(file);
Sheet sheet = workbook.getSheetAt(0);
// 2. 从Excel文件中读取数据
int rowCount = sheet.getLastRowNum();
double[] x = new double[rowCount];
int[] v = new int[rowCount];
for (int i = 0; i < rowCount; i++) {
//手动没有读取编号行
Row row = sheet.getRow(i + 1);
if (row != null) {
Cell cellX = row.getCell(1); // 刺激量
Cell cellV = row.getCell(2); // 响应
if (cellX != null && cellX.getCellType() == CellType.NUMERIC) {
x[i] = cellX.getNumericCellValue();//一开始部分数据是文本类型,防止出现这种情况
}
if (cellV != null && cellV.getCellType() == CellType.NUMERIC) {
v[i] = (int) cellV.getNumericCellValue();
}
}
}
// System.out.println("刺激量: " + Arrays.toString(x));
// System.out.println("响应: " + Arrays.toString(v));
//计时开始 对应原tic
long startTime = System.nanoTime();
//3.初始化
double beta2 = 3.84146 / 2;
double x1l = findMin(x, v);
double x0u = findMax(x, v);
double nm = countInRange(x1l, x0u, x);
double mu0 = 0.5 * (x0u + x1l);
double sigma0 = rowCount * (x0u - x1l) / (8 * (nm + 2));
//4.进行循环迭代,将第一次迭代的初始化写入了循环内部
//设置一个计数器
int times = 1;
while (true) {
// 第一次迭代
double[] u = new double[rowCount];
double[] z = new double[rowCount];
double[] p = new double[rowCount];
double[] h = new double[rowCount];
for (int i = 0; i < rowCount; i++) {
u[i] = (x[i] - mu0) / sigma0;
z[i] = Math.exp(-0.5 * u[i] * u[i]) / Math.sqrt(2 * Math.PI);
p[i] = 0.5 * (1 + Erf.erf(u[i] / Math.sqrt(2)));
h[i] = v[i] / p[i] - (1 - v[i]) / (1 - p[i]);
}
//转置乘积
double f = dotProduct(z, h);
double g = dotProduct(u, z, h);
//fmu
double[] temp_pow = new double[z.length];
double[] temp_zh = elementMultiply(z, h);
for (int i = 0; i < z.length; i++) {
temp_pow[i] = Math.pow(temp_zh[i], 2);
}
double fmu = (g + dotProduct(temp_pow)) / sigma0;
//fsigma
double[] temp_uzh = elementMultiply(u, temp_zh);
double[] temp_u_zh = new double[temp_zh.length];
if (u.length != temp_zh.length) {
throw new IllegalArgumentException("向量维度不匹配。请检查参数");
} else {
for (int i = 0; i < u.length; i++) {
temp_u_zh[i] = u[i] + temp_zh[i];
}
}
double fsigma = dotProduct(elementMultiply(temp_uzh, temp_u_zh)) / sigma0;
//gsigma
double gmu = f / sigma0 + fsigma;
double[] finTemp = elementMultiply(elementMultiply(u, temp_uzh), temp_u_zh);
double gsigma = (dotProduct(finTemp) - g) / sigma0;
double[] sol = solveLinear(fmu, fsigma, gmu, gsigma, f, g);
//判断终止条件
if (Math.abs(sol[0]) + Math.abs(sol[1]) < 0.001) {
// System.out.println(sol[0]);
// System.out.println(sol[1]);
// System.out.println(times);
break;
}
//更新
mu0 += sol[0];
sigma0 += sol[1];
times++;
}
// %以0.999的概率可以发火的电压值
double x0999 = mu0 + 3.09 * sigma0;
//5.置信区间估计
List Vlist = new ArrayList<>();
List Mlist = new ArrayList<>();
for (int i = 0; i < v.length; i++) {
if (v[i] == 1) {
Vlist.add(x[i]);
} else if (v[i] == 0) {
Mlist.add(x[i]);
}
}
double[] V = Vlist.stream().mapToDouble(i -> i).toArray();
double[] M = Mlist.stream().mapToDouble(i -> i).toArray();
NormalDistribution distribution = new NormalDistribution(mu0, sigma0);
//计算Lx,省略了构造新表Vp,Mp.
double Lx = 0;
for (double vi : V) {
Lx += Math.log(distribution.cumulativeProbability(vi));
}
for (double mi : M) {
Lx += Math.log(1 - distribution.cumulativeProbability(mi));
}
//计算LL,类似简化
double mu_start = mu0 - 4 * sigma0;
double mu_end = mu0;
//二分法求最小mu
double mu_min = 0;
double mu_max = 0;
while (true) {
double mu_m = (mu_end + mu_start) / 2;
NormalDistribution distribution1 = new NormalDistribution(mu_m, sigma0);
double LL = 0;
for (double vi : V) {
LL += Math.log(distribution1.cumulativeProbability(vi));
}
for (double mi : M) {
LL += Math.log(1 - distribution1.cumulativeProbability(mi));
}
double d_LL = LL - (Lx - beta2);
if (Math.abs(d_LL) < 1e-5) {
mu_min = mu_m;
mu_max = 2 * mu0 - mu_m;
break;
}
if (d_LL < 0) {
mu_start = mu_m;
} else {
mu_end = mu_m;
}
}
//遍历四个区域获得(mu, sigma0)对
double d_mu = 0.0001;
double d_sigma = 0.0001;
List mu_sg_edge = new ArrayList<>();
mu_sg_edge.addAll(searchBoundary(1, mu0, mu_min, mu_max, sigma0,
d_mu, d_sigma, V, M, Lx, beta2));
mu_sg_edge.addAll(searchBoundary(2, mu0, mu_min, mu_max, sigma0,
d_mu, d_sigma, V, M, Lx, beta2));
mu_sg_edge.addAll(searchBoundary(3, mu0, mu_min, mu_max, sigma0,