소스 검색

Add LSTM test case

Jason Xing 3 년 전
부모
커밋
b3a4057feb
2개의 변경된 파일104개의 추가작업 그리고 3개의 파일을 삭제
  1. 3 3
      brain-js.test-01.html
  2. 101 0
      brain-js.test-03-LSTM-math.html

+ 3 - 3
brain-js.test-01.html

@@ -32,11 +32,11 @@
     // provide optional config object (or undefined). Defaults shown.
     const config = {
       binaryThresh: 0.5,
-      hiddenLayers: [3], // array of ints for the sizes of the hidden layers in the network
+      hiddenLayers: [5], // array of ints for the sizes of the hidden layers in the network
       activation: "relu", // supported activation types: ['sigmoid', 'relu', 'leaky-relu', 'tanh'],
       leakyReluAlpha: 0.01, // supported for activation type 'leaky-relu'
-      inputSize: 2,
-      outputSize: 1,
+      // inputSize: 2,
+      // outputSize: 2,
     };
 
     // create a simple feed forward neural network with backpropagation

+ 101 - 0
brain-js.test-03-LSTM-math.html

@@ -0,0 +1,101 @@
+<!DOCTYPE html>
+<html>
+
+<head>
+    <title>Brain JS - LSTM Math</title>
+    <meta charset="utf-8" />
+    <meta http-equiv="Content-Type"
+          content="text/html; charset=utf-8" />
+    <style>
+        #network {
+        background: #ccc;
+        border: 1px solid red;
+        width: 640px;
+        height: 480px;
+      }
+    </style>
+</head>
+
+<body>
+    <h1>Brain JS - LSTM Math</h1>
+    <p>
+        In this case, we test the math adding calcuation by LSTM model.
+    </p>
+    <h2>Test results</h2>
+    <div id="input">
+        <p>Input a math problem, like "1+2=" or "9+7=" in the below input box.</p>
+        <input type="text"
+               id="problem"></input>
+        <button id="calcuate">Calculate</button>
+        <button id="testall">Test all once</button>
+    </div>
+    <div id="output"></div>
+    <h2>Network</h2>
+    <div id="network"></div>
+</body>
+<script src="brain-browser.js"></script>
+<script type="text/javascript">
+const svg_config = {
+  height: 480,
+  width: 640,
+  radius: 10,
+};
+
+// create a simple recurrent neural network
+const net = new brain.recurrent.LSTM();
+
+// used to build list below
+const mathProblemsSet = new Set();
+for (let i = 0; i < 10; i++) {
+  for (let j = 0; j < 10; j++) {
+    mathProblemsSet.add(`${i}+${j}=${i + j}`);
+    mathProblemsSet.add(`${j}+${i}=${i + j}`);
+  }
+}
+const mathProblems = Array.from(mathProblemsSet);
+console.log(mathProblems);
+
+net.train(mathProblems, { log: true, errorThresh: 0.03 });
+
+
+let errors = 0;
+document.getElementById("calcuate").onclick = () => {
+  const input = document.getElementById("problem").value;
+  const output = net.run(input)
+
+  const predictedMathProblem = input + output;
+  const error = mathProblems.indexOf(predictedMathProblem) <= -1 ? 1 : 0;
+  errors += error;
+
+  document.getElementById("output").innerHTML =
+    `Output :${input}${output}` +
+    `<br/>${error?'Error':'Correct'}, total error ${errors}`;
+};
+
+// Batch test for all problems
+document.getElementById("testall").onclick = () => {
+  let results = 'Test all items:<br/><ul>';
+  for (let i = 0; i < mathProblems.length; i++) {
+    const input = mathProblems[i].split('=')[0] + '=';
+    const output = net.run(input);
+    const predictedMathProblem = input + output;
+    const error = mathProblems.indexOf(predictedMathProblem) <= -1 ? 1 : 0;
+    errors += error;
+    results += '<li>';
+    results += input + output;
+    results += error ? ' - error' : '';
+    results += '</li>';
+  }
+  results +='</ul>';
+
+  results = `Batch test with ${errors} errors <br/>`+ results;
+
+  document.getElementById("output").innerHTML = results;
+};
+document.getElementById("network").innerHTML = brain.utilities.toSVG(
+  net,
+  svg_config
+);
+</script>
+
+</html>