"use strict"; // 常量定义 const NUM_CLASSES = 8; const IMAGE_SIZE = 227; // MobileNet V1 通常需要 227x227 或 224x224,请根据实际模型调整 const TOPK = 3; class Main { // --- 只有一个构造函数 --- constructor() { this.infoTexts = []; this.trainingClass = -1; // 用于标记哪个按钮被点击以进行训练 this.currentImageElement = null; // 当前用于预览和处理的图片元素 this.isAppReady = false; this.trainingButtons = []; // *** 在这里初始化 trainingButtons *** // 获取 DOM 元素 this.fileInput = document.getElementById("file-input"); this.imagePreview = document.getElementById("image-preview"); this.imagePlaceholder = document.getElementById("preview-placeholder"); this.classifyButton = document.getElementById("classify-button"); this.controlsContainer = document.getElementById("controls-container"); this.statusDiv = document.getElementById("status"); // 使用 CDN 加载的全局变量 this.tf = tf; // 来自 @tensorflow/tfjs this.knnClassifier = knnClassifier; // 来自 @tensorflow-models/knn-classifier this.mobilenet = mobilenet; // 来自 @tensorflow-models/mobilenet // 初始化页面和模型 this.bindPage(); // 创建训练按钮和信息文本 (现在可以安全调用了) this.createTrainingControls(); // 绑定事件监听器 this.bindEventListeners(); } // --- 构造函数结束 --- updateStatus(message) { this.statusDiv.innerText = message; console.log(message); // 同时在控制台输出 } createTrainingControls() { for (let i = 0; i < NUM_CLASSES; i++) { const div = document.createElement("div"); div.classList.add("class-controls"); // 添加 CSS 类 this.controlsContainer.appendChild(div); // 创建训练按钮 const button = document.createElement("button"); button.innerText = `训练物品 ${i}`; button.disabled = true; // 初始禁用,直到模型加载完成 div.appendChild(button); // 点击按钮时,使用当前预览的图片进行训练 button.addEventListener("click", async () => { if ( this.imagePreview.style.display === "none" || !this.imagePreview.src || this.imagePreview.src.endsWith("#") ) { alert("请先选择一张图片再进行训练!"); return; } this.updateStatus(`正在为物品 ${i} 添加样本...`); await this.addImageForTraining(i); this.updateStatus(`物品 ${i} 的样本已添加。`); }); // 创建信息文本 const infoText = document.createElement("span"); infoText.innerText = " 0 个学习样例"; div.appendChild(infoText); this.infoTexts.push(infoText); this.trainingButtons.push(button); // 保存按钮引用以便启用/禁用 } } bindEventListeners() { // 文件选择事件 this.fileInput.addEventListener("change", (event) => { const file = event.target.files[0]; if (file && file.type.startsWith("image/")) { const reader = new FileReader(); reader.onload = (e) => { this.imagePreview.src = e.target.result; this.imagePreview.style.display = "block"; this.imagePlaceholder.style.display = "none"; // 隐藏占位符 this.classifyButton.disabled = !this.isAppReady; // 启用识别按钮(如果应用已就绪) // 启用训练按钮 this.trainingButtons.forEach( (btn) => (btn.disabled = !this.isAppReady) ); this.updateStatus("图片已加载,可以进行识别或训练。"); }; reader.readAsDataURL(file); } else { // 清除预览并禁用按钮 this.imagePreview.src = "#"; this.imagePreview.style.display = "none"; this.imagePlaceholder.style.display = "block"; // 显示占位符 this.classifyButton.disabled = true; this.trainingButtons.forEach((btn) => (btn.disabled = true)); if (file) { alert("请选择一个图片文件!"); this.updateStatus("请选择一个图片文件。"); } else { this.updateStatus("未选择图片。"); } } }); // 识别按钮点击事件 this.classifyButton.addEventListener("click", async () => { if ( this.imagePreview.style.display === "none" || !this.imagePreview.src || this.imagePreview.src.endsWith("#") ) { alert("请先选择一张图片再进行识别!"); return; } await this.classifyCurrentImage(); }); } async bindPage() { try { this.updateStatus("正在初始化应用..."); // 创建 KNN 分类器实例 this.knn = this.knnClassifier.create(); this.updateStatus("KNN 分类器已创建"); // 加载 MobileNet 模型 this.updateStatus("正在加载 MobileNet 模型..."); try { this.mobilenetModel = await this.mobilenet.load({ version: 1, alpha: 1.0, }); console.log("MobileNet V1 模型加载完成"); } catch (e) { console.warn("MobileNet V1 加载失败,尝试加载 V2:", e); this.mobilenetModel = await this.mobilenet.load({ version: 2, alpha: 1.0, }); console.log("MobileNet V2 模型加载完成"); } this.updateStatus("模型加载完成,应用准备就绪。"); this.isAppReady = true; // 如果用户在模型加载前就选了图片,现在启用按钮 if (this.imagePreview.style.display !== "none") { this.classifyButton.disabled = false; this.trainingButtons.forEach((btn) => (btn.disabled = false)); } } catch (error) { console.error("模型加载或初始化失败:", error); this.updateStatus("错误:模型加载失败。请检查网络或控制台。"); alert("模型加载失败,请检查网络连接或控制台错误。"); } } // 提取特征的辅助函数 async getFeaturesFromImage(imageElement) { // *** 使用 tf.tidy 来管理此函数内部的张量 *** return tf.tidy(() => { let imageTensor = this.tf.browser.fromPixels(imageElement); // 调整大小并标准化 let expectedSize = IMAGE_SIZE; try { if ( this.mobilenetModel && this.mobilenetModel.inputs && this.mobilenetModel.inputs[0].shape ) { const shape = this.mobilenetModel.inputs[0].shape; if ( shape && shape.length >= 3 && shape[1] != null && shape[2] != null && shape[1] === shape[2] ) { expectedSize = shape[1]; console.log(`检测到模型输入尺寸: ${expectedSize}x${expectedSize}`); } else { console.warn( `无法从模型输入形状 ${JSON.stringify( shape )} 推断尺寸,使用默认值 ${IMAGE_SIZE}。` ); } } } catch (e) { console.warn( `获取模型输入尺寸时出错,使用默认值 ${IMAGE_SIZE}。错误: ${e}` ); } let resizedTensor = tf.image .resizeBilinear(imageTensor, [expectedSize, expectedSize], true) .expandDims(0); let normalizedTensor = resizedTensor.toFloat().div(127.5).sub(1.0); // 特征提取 let logits = null; try { if (typeof this.mobilenetModel.infer === "function") { logits = this.mobilenetModel.infer(normalizedTensor, true); } else if (typeof this.mobilenetModel.predict === "function") { logits = this.mobilenetModel.predict(normalizedTensor); if (Array.isArray(logits)) { logits = logits[0]; } if (logits.rank > 2) { const batchSize = logits.shape[0]; logits = logits.reshape([batchSize, -1]); } } else { throw new Error("无法识别的模型接口"); } } catch (error) { console.error("MobileNet 推理失败: ", error); return null; // 返回 null 表示失败 } // *** tidy 会自动清理 imageTensor, resizedTensor, normalizedTensor *** // *** 但 logits 因为被返回,所以不会被这个 tidy 清理 *** return logits; }); // *** tidy 结束 *** } // 添加训练样本 (修改后) async addImageForTraining(classIndex) { if (!this.isAppReady) { alert("模型尚未准备好,请稍候。"); return; } if ( !this.imagePreview || this.imagePreview.style.display === "none" || !this.imagePreview.src || this.imagePreview.src.endsWith("#") ) { alert("请先选择一张图片。"); return; } // 1. 获取特征 (内部使用 tidy) const logits = await this.getFeaturesFromImage(this.imagePreview); if (logits) { // 2. 执行异步 KNN 操作 (不在 tidy 内部) this.knn.addExample(logits, classIndex); this.updateExampleCountUI(classIndex); // 3. **手动释放不再需要的 logits** (因为它被 getFeaturesFromImage 的 tidy 返回了) logits.dispose(); } else { this.updateStatus(`错误:无法从图片提取特征用于训练物品 ${classIndex}`); } } // 分类当前图片 (修改后) async classifyCurrentImage() { if (!this.isAppReady) { alert("模型尚未准备好,请稍候。"); return; } if ( !this.imagePreview || this.imagePreview.style.display === "none" || !this.imagePreview.src || this.imagePreview.src.endsWith("#") ) { alert("请先选择一张图片。"); return; } this.updateStatus("正在识别图片..."); const numClasses = this.knn.getNumClasses(); if (numClasses === 0) { this.updateStatus("错误:您还没有训练任何物品类别!"); alert("请先训练至少一个物品类别再进行识别。"); return; } // 1. 获取特征 (内部使用 tidy) const logits = await this.getFeaturesFromImage(this.imagePreview); if (logits) { // 2. 执行异步 KNN 操作 (不在 tidy 内部) const res = await this.knn.predictClass(logits, TOPK); this.updatePredictionUI(res); this.updateStatus( `识别完成。最可能的类别: 物品 ${res.classIndex} (${( res.confidences[res.classIndex] * 100 ).toFixed(1)}%)` ); // 3. **手动释放不再需要的 logits** logits.dispose(); } else { this.updateStatus("错误:无法从图片提取特征进行识别。"); } } // 更新特定类别的样本计数UI (保持不变) updateExampleCountUI(classIndex) { const exampleCount = this.knn.getClassExampleCount(); if (this.infoTexts[classIndex]) { const count = exampleCount[classIndex] || 0; const currentText = this.infoTexts[classIndex].innerText; const parts = currentText.split(" - "); this.infoTexts[classIndex].innerText = ` ${count} 个学习样例` + (parts.length > 1 ? ` - ${parts[1]}` : ""); this.infoTexts[classIndex].style.fontWeight = "normal"; } } // 更新所有类别的预测UI updatePredictionUI(predictionResult) { const exampleCount = this.knn.getClassExampleCount(); for (let i = 0; i < NUM_CLASSES; i++) { const infoText = this.infoTexts[i]; if (infoText) { // 确保元素存在 const count = exampleCount[i] || 0; // 使用 ?. 确保 confidences 存在且属性存在 const confidence = predictionResult?.confidences?.[i] != null ? (predictionResult.confidences[i] * 100).toFixed(1) : "0.0"; // 默认为 0% infoText.innerText = ` ${count} 个学习样例 - ${confidence}%`; // 加粗预测结果 if (predictionResult.classIndex === i) { infoText.style.fontWeight = "bold"; } else { infoText.style.fontWeight = "normal"; } } } } } // 页面加载完成后启动应用 window.addEventListener("load", () => new Main());