app.js 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. "use strict";
  2. // 常量定义
  3. const NUM_CLASSES = 8;
  4. const IMAGE_SIZE = 227; // MobileNet V1 通常需要 227x227 或 224x224,请根据实际模型调整
  5. const TOPK = 3;
  6. class Main {
  7. // --- 只有一个构造函数 ---
  8. constructor() {
  9. this.infoTexts = [];
  10. this.trainingClass = -1; // 用于标记哪个按钮被点击以进行训练
  11. this.currentImageElement = null; // 当前用于预览和处理的图片元素
  12. this.isAppReady = false;
  13. this.trainingButtons = []; // *** 在这里初始化 trainingButtons ***
  14. // 获取 DOM 元素
  15. this.fileInput = document.getElementById("file-input");
  16. this.imagePreview = document.getElementById("image-preview");
  17. this.imagePlaceholder = document.getElementById("preview-placeholder");
  18. this.classifyButton = document.getElementById("classify-button");
  19. this.controlsContainer = document.getElementById("controls-container");
  20. this.statusDiv = document.getElementById("status");
  21. // 使用 CDN 加载的全局变量
  22. this.tf = tf; // 来自 @tensorflow/tfjs
  23. this.knnClassifier = knnClassifier; // 来自 @tensorflow-models/knn-classifier
  24. this.mobilenet = mobilenet; // 来自 @tensorflow-models/mobilenet
  25. // 初始化页面和模型
  26. this.bindPage();
  27. // 创建训练按钮和信息文本 (现在可以安全调用了)
  28. this.createTrainingControls();
  29. // 绑定事件监听器
  30. this.bindEventListeners();
  31. }
  32. // --- 构造函数结束 ---
  33. updateStatus(message) {
  34. this.statusDiv.innerText = message;
  35. console.log(message); // 同时在控制台输出
  36. }
  37. createTrainingControls() {
  38. for (let i = 0; i < NUM_CLASSES; i++) {
  39. const div = document.createElement("div");
  40. div.classList.add("class-controls"); // 添加 CSS 类
  41. this.controlsContainer.appendChild(div);
  42. // 创建训练按钮
  43. const button = document.createElement("button");
  44. button.innerText = `训练物品 ${i}`;
  45. button.disabled = true; // 初始禁用,直到模型加载完成
  46. div.appendChild(button);
  47. // 点击按钮时,使用当前预览的图片进行训练
  48. button.addEventListener("click", async () => {
  49. if (
  50. this.imagePreview.style.display === "none" ||
  51. !this.imagePreview.src ||
  52. this.imagePreview.src.endsWith("#")
  53. ) {
  54. alert("请先选择一张图片再进行训练!");
  55. return;
  56. }
  57. this.updateStatus(`正在为物品 ${i} 添加样本...`);
  58. await this.addImageForTraining(i);
  59. this.updateStatus(`物品 ${i} 的样本已添加。`);
  60. });
  61. // 创建信息文本
  62. const infoText = document.createElement("span");
  63. infoText.innerText = " 0 个学习样例";
  64. div.appendChild(infoText);
  65. this.infoTexts.push(infoText);
  66. this.trainingButtons.push(button); // 保存按钮引用以便启用/禁用
  67. }
  68. }
  69. bindEventListeners() {
  70. // 文件选择事件
  71. this.fileInput.addEventListener("change", (event) => {
  72. const file = event.target.files[0];
  73. if (file && file.type.startsWith("image/")) {
  74. const reader = new FileReader();
  75. reader.onload = (e) => {
  76. this.imagePreview.src = e.target.result;
  77. this.imagePreview.style.display = "block";
  78. this.imagePlaceholder.style.display = "none"; // 隐藏占位符
  79. this.classifyButton.disabled = !this.isAppReady; // 启用识别按钮(如果应用已就绪)
  80. // 启用训练按钮
  81. this.trainingButtons.forEach(
  82. (btn) => (btn.disabled = !this.isAppReady)
  83. );
  84. this.updateStatus("图片已加载,可以进行识别或训练。");
  85. };
  86. reader.readAsDataURL(file);
  87. } else {
  88. // 清除预览并禁用按钮
  89. this.imagePreview.src = "#";
  90. this.imagePreview.style.display = "none";
  91. this.imagePlaceholder.style.display = "block"; // 显示占位符
  92. this.classifyButton.disabled = true;
  93. this.trainingButtons.forEach((btn) => (btn.disabled = true));
  94. if (file) {
  95. alert("请选择一个图片文件!");
  96. this.updateStatus("请选择一个图片文件。");
  97. } else {
  98. this.updateStatus("未选择图片。");
  99. }
  100. }
  101. });
  102. // 识别按钮点击事件
  103. this.classifyButton.addEventListener("click", async () => {
  104. if (
  105. this.imagePreview.style.display === "none" ||
  106. !this.imagePreview.src ||
  107. this.imagePreview.src.endsWith("#")
  108. ) {
  109. alert("请先选择一张图片再进行识别!");
  110. return;
  111. }
  112. await this.classifyCurrentImage();
  113. });
  114. }
  115. async bindPage() {
  116. try {
  117. this.updateStatus("正在初始化应用...");
  118. // 创建 KNN 分类器实例
  119. this.knn = this.knnClassifier.create();
  120. this.updateStatus("KNN 分类器已创建");
  121. // 加载 MobileNet 模型
  122. this.updateStatus("正在加载 MobileNet 模型...");
  123. try {
  124. this.mobilenetModel = await this.mobilenet.load({
  125. version: 1,
  126. alpha: 1.0,
  127. });
  128. console.log("MobileNet V1 模型加载完成");
  129. } catch (e) {
  130. console.warn("MobileNet V1 加载失败,尝试加载 V2:", e);
  131. this.mobilenetModel = await this.mobilenet.load({
  132. version: 2,
  133. alpha: 1.0,
  134. });
  135. console.log("MobileNet V2 模型加载完成");
  136. }
  137. this.updateStatus("模型加载完成,应用准备就绪。");
  138. this.isAppReady = true;
  139. // 如果用户在模型加载前就选了图片,现在启用按钮
  140. if (this.imagePreview.style.display !== "none") {
  141. this.classifyButton.disabled = false;
  142. this.trainingButtons.forEach((btn) => (btn.disabled = false));
  143. }
  144. } catch (error) {
  145. console.error("模型加载或初始化失败:", error);
  146. this.updateStatus("错误:模型加载失败。请检查网络或控制台。");
  147. alert("模型加载失败,请检查网络连接或控制台错误。");
  148. }
  149. }
  150. // 提取特征的辅助函数
  151. async getFeaturesFromImage(imageElement) {
  152. // *** 使用 tf.tidy 来管理此函数内部的张量 ***
  153. return tf.tidy(() => {
  154. let imageTensor = this.tf.browser.fromPixels(imageElement);
  155. // 调整大小并标准化
  156. let expectedSize = IMAGE_SIZE;
  157. try {
  158. if (
  159. this.mobilenetModel &&
  160. this.mobilenetModel.inputs &&
  161. this.mobilenetModel.inputs[0].shape
  162. ) {
  163. const shape = this.mobilenetModel.inputs[0].shape;
  164. if (
  165. shape &&
  166. shape.length >= 3 &&
  167. shape[1] != null &&
  168. shape[2] != null &&
  169. shape[1] === shape[2]
  170. ) {
  171. expectedSize = shape[1];
  172. console.log(`检测到模型输入尺寸: ${expectedSize}x${expectedSize}`);
  173. } else {
  174. console.warn(
  175. `无法从模型输入形状 ${JSON.stringify(
  176. shape
  177. )} 推断尺寸,使用默认值 ${IMAGE_SIZE}。`
  178. );
  179. }
  180. }
  181. } catch (e) {
  182. console.warn(
  183. `获取模型输入尺寸时出错,使用默认值 ${IMAGE_SIZE}。错误: ${e}`
  184. );
  185. }
  186. let resizedTensor = tf.image
  187. .resizeBilinear(imageTensor, [expectedSize, expectedSize], true)
  188. .expandDims(0);
  189. let normalizedTensor = resizedTensor.toFloat().div(127.5).sub(1.0);
  190. // 特征提取
  191. let logits = null;
  192. try {
  193. if (typeof this.mobilenetModel.infer === "function") {
  194. logits = this.mobilenetModel.infer(normalizedTensor, true);
  195. } else if (typeof this.mobilenetModel.predict === "function") {
  196. logits = this.mobilenetModel.predict(normalizedTensor);
  197. if (Array.isArray(logits)) {
  198. logits = logits[0];
  199. }
  200. if (logits.rank > 2) {
  201. const batchSize = logits.shape[0];
  202. logits = logits.reshape([batchSize, -1]);
  203. }
  204. } else {
  205. throw new Error("无法识别的模型接口");
  206. }
  207. } catch (error) {
  208. console.error("MobileNet 推理失败: ", error);
  209. return null; // 返回 null 表示失败
  210. }
  211. // *** tidy 会自动清理 imageTensor, resizedTensor, normalizedTensor ***
  212. // *** 但 logits 因为被返回,所以不会被这个 tidy 清理 ***
  213. return logits;
  214. }); // *** tidy 结束 ***
  215. }
  216. // 添加训练样本 (修改后)
  217. async addImageForTraining(classIndex) {
  218. if (!this.isAppReady) {
  219. alert("模型尚未准备好,请稍候。");
  220. return;
  221. }
  222. if (
  223. !this.imagePreview ||
  224. this.imagePreview.style.display === "none" ||
  225. !this.imagePreview.src ||
  226. this.imagePreview.src.endsWith("#")
  227. ) {
  228. alert("请先选择一张图片。");
  229. return;
  230. }
  231. // 1. 获取特征 (内部使用 tidy)
  232. const logits = await this.getFeaturesFromImage(this.imagePreview);
  233. if (logits) {
  234. // 2. 执行异步 KNN 操作 (不在 tidy 内部)
  235. this.knn.addExample(logits, classIndex);
  236. this.updateExampleCountUI(classIndex);
  237. // 3. **手动释放不再需要的 logits** (因为它被 getFeaturesFromImage 的 tidy 返回了)
  238. logits.dispose();
  239. } else {
  240. this.updateStatus(`错误:无法从图片提取特征用于训练物品 ${classIndex}`);
  241. }
  242. }
  243. // 分类当前图片 (修改后)
  244. async classifyCurrentImage() {
  245. if (!this.isAppReady) {
  246. alert("模型尚未准备好,请稍候。");
  247. return;
  248. }
  249. if (
  250. !this.imagePreview ||
  251. this.imagePreview.style.display === "none" ||
  252. !this.imagePreview.src ||
  253. this.imagePreview.src.endsWith("#")
  254. ) {
  255. alert("请先选择一张图片。");
  256. return;
  257. }
  258. this.updateStatus("正在识别图片...");
  259. const numClasses = this.knn.getNumClasses();
  260. if (numClasses === 0) {
  261. this.updateStatus("错误:您还没有训练任何物品类别!");
  262. alert("请先训练至少一个物品类别再进行识别。");
  263. return;
  264. }
  265. // 1. 获取特征 (内部使用 tidy)
  266. const logits = await this.getFeaturesFromImage(this.imagePreview);
  267. if (logits) {
  268. // 2. 执行异步 KNN 操作 (不在 tidy 内部)
  269. const res = await this.knn.predictClass(logits, TOPK);
  270. this.updatePredictionUI(res);
  271. this.updateStatus(
  272. `识别完成。最可能的类别: 物品 ${res.classIndex} (${(
  273. res.confidences[res.classIndex] * 100
  274. ).toFixed(1)}%)`
  275. );
  276. // 3. **手动释放不再需要的 logits**
  277. logits.dispose();
  278. } else {
  279. this.updateStatus("错误:无法从图片提取特征进行识别。");
  280. }
  281. }
  282. // 更新特定类别的样本计数UI (保持不变)
  283. updateExampleCountUI(classIndex) {
  284. const exampleCount = this.knn.getClassExampleCount();
  285. if (this.infoTexts[classIndex]) {
  286. const count = exampleCount[classIndex] || 0;
  287. const currentText = this.infoTexts[classIndex].innerText;
  288. const parts = currentText.split(" - ");
  289. this.infoTexts[classIndex].innerText =
  290. ` ${count} 个学习样例` + (parts.length > 1 ? ` - ${parts[1]}` : "");
  291. this.infoTexts[classIndex].style.fontWeight = "normal";
  292. }
  293. }
  294. // 更新所有类别的预测UI
  295. updatePredictionUI(predictionResult) {
  296. const exampleCount = this.knn.getClassExampleCount();
  297. for (let i = 0; i < NUM_CLASSES; i++) {
  298. const infoText = this.infoTexts[i];
  299. if (infoText) {
  300. // 确保元素存在
  301. const count = exampleCount[i] || 0;
  302. // 使用 ?. 确保 confidences 存在且属性存在
  303. const confidence =
  304. predictionResult?.confidences?.[i] != null
  305. ? (predictionResult.confidences[i] * 100).toFixed(1)
  306. : "0.0"; // 默认为 0%
  307. infoText.innerText = ` ${count} 个学习样例 - ${confidence}%`;
  308. // 加粗预测结果
  309. if (predictionResult.classIndex === i) {
  310. infoText.style.fontWeight = "bold";
  311. } else {
  312. infoText.style.fontWeight = "normal";
  313. }
  314. }
  315. }
  316. }
  317. }
  318. // 页面加载完成后启动应用
  319. window.addEventListener("load", () => new Main());