brain-js.test-04-classifier.html 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. <!DOCTYPE html>
  2. <html>
  3. <head>
  4. <title>Brain JS - NeuralNetwork</title>
  5. <meta charset="utf-8" />
  6. <meta http-equiv="Content-Type"
  7. content="text/html; charset=utf-8" />
  8. <style>
  9. #network {
  10. background: #ccc;
  11. border: 1px solid red;
  12. width: 640px;
  13. height: 480px;
  14. }
  15. </style>
  16. </head>
  17. <body>
  18. <h1>Brain JS - NeuralNetwork</h1>
  19. <p>
In this case, we test binary classification of 2-D points with a Neural Network model.
  21. </p>
  22. <h2>Test data</h2>
  23. <p>
This chart shows all of the sample data.
The task is to decide whether each point lies inside the circle centred
at (0.15, 0) with radius 0.25. The network is trained on part of the data
and then predicts the class of the remaining points, which are drawn below.
  28. </p>
  29. <div id="all-data"
  30. style="position: relative; width:600px">
  31. <canvas id="allDataChart"></canvas>
  32. </div>
<p>This chart shows the network's predictions for the test points.
It should match the corresponding part of the chart above.
  35. </p>
  36. <div id="test-data"
  37. style="position: relative; width:600px">
  38. <canvas id="testDataChart"></canvas>
  39. </div>
  40. <h2>Network</h2>
  41. <p>The network structure is shown as below:</p>
  42. <div id="network"></div>
  43. </body>
  44. <script src="brain-browser.js"></script>
  45. <script src="chart.js"></script>
  46. <script type="text/javascript">
  47. function criteria(x, y) {
  48. const r = Math.sqrt((x - 0.15) ** 2 + y ** 2)
  49. if (r < 0.25) {
  50. return 1;
  51. } else {
  52. return 0;
  53. }
  54. }
  55. function generator(num = 1000) {
  56. let data = [];
  57. for (let i = 0; i < num; i++) {
  58. const x = Math.random() - 0.5;
  59. const y = Math.random() - 0.5;
  60. data.push({
  61. x: x,
  62. y: y,
  63. })
  64. }
  65. return data;
  66. }
  67. function generateInputOutput(data_input, data_output, split) {
  68. const train_input = data_input.slice(0, split);
  69. const train_output = data_output.slice(0, split);
  70. const test_input = data_input.slice(split);
  71. const test_output = data_output.slice(split);
  72. let results = {
  73. train: [],
  74. test: {
  75. input: [],
  76. output: [],
  77. },
  78. all: [],
  79. };
  80. for (let i = 0; i < train_input.length; i++) {
  81. results['train'].push({
  82. // input: [train_input[i]['x'],train_input[i]['y']], // RNN
  83. input: train_input[i], // NeuralNetwork
  84. output: [train_output[i]],
  85. });
  86. }
  87. for (let i = 0; i < test_input.length; i++) {
  88. results['test']['input'].push(test_input[i]);
  89. results['test']['output'].push(test_output[i]);
  90. }
  91. for (let i = 0; i < data_input.length; i++) {
  92. results['all'].push({
  93. input: data_input[i],
  94. output: [data_output[i]],
  95. });
  96. }
  97. return results;
  98. }
  99. // Generate the datasets for the chart
  100. function generateChartsData(data) {
  101. let datasets0 = {
  102. label: '0',
  103. data: [],
  104. backgroundColor: 'rgba(50,0,0,0.1)',
  105. radius: 6,
  106. tyle: 'cicle',
  107. };
  108. let datasets1 = {
  109. label: '1',
  110. data: [],
  111. backgroundColor: 'rgba(0, 100,200,0.05)',
  112. radius: 10,
  113. tyle: 'triangle',
  114. };
  115. for (let i = 0; i < data.length; i++) {
  116. let item = { x: data[i]['input']['x'], y: data[i]['input']['y'] }
  117. if (data[i]['output'] == 0) {
  118. datasets0['data'].push(item);
  119. } else {
  120. datasets1['data'].push(item);
  121. }
  122. }
  123. return [datasets0, datasets1];
  124. }
  125. const TOTAL_SAMPLES = 100;
  126. const TRAIN_SAMPLES = 60;
  127. const data_input = generator(TOTAL_SAMPLES);
  128. const data_output = data_input.map(d => criteria(d['x'], d['y']));
  129. const dataset = generateInputOutput(data_input, data_output, TRAIN_SAMPLES);
  130. const ctx = document.getElementById("allDataChart").getContext('2d');
  131. var scatterChart = new Chart(ctx, {
  132. type: "scatter",
  133. data: {
  134. datasets: generateChartsData(dataset['all'])
  135. },
  136. options: {
  137. responsive: true,
  138. }
  139. });
  140. const svg_config = {
  141. height: 480,
  142. width: 640,
  143. radius: 10,
  144. };
  145. console.log(dataset);
  146. // create a simple recurrent neural network
  147. const net = new brain.NeuralNetwork();
  148. net.train(dataset['train'], { log: true, errorThresh: 0.03 });
  149. let test_predict = [];
  150. dataset['test']['input'].forEach((item) => {
  151. test_predict.push(Math.floor(net.run(item)[0] + 0.5));
  152. });
  153. console.log(test_predict);
  154. let testChartsData = [];
  155. for (let i = 0; i < dataset['test']['input'].length; i++) {
  156. testChartsData.push({
  157. input: dataset['test']['input'][i],
  158. output: test_predict[i],
  159. })
  160. }
  161. const ctxTest = document.getElementById("testDataChart").getContext('2d');
  162. const scatterChartTest = new Chart(ctxTest, {
  163. type: "scatter",
  164. data: {
  165. datasets: generateChartsData(testChartsData),
  166. },
  167. options: {
  168. responsive: true,
  169. }
  170. });
  171. document.getElementById("network").innerHTML = brain.utilities.toSVG(
  172. net,
  173. svg_config
  174. );
  175. </script>
  176. </html>