ตัวอย่างการใช้งาน Java Neuroph ด้วยการประยุกต์แทรก Prompt ไปผสมกับ Latent Vector เพื่อใช้เป็นตัวกำหนดผลลัพธ์ร่วมนะครับ โดยยกตัวอย่าง Prompt สำหรับงาน Graphic เช่น Grayscale และ Sepia
import java.util.Arrays;
import org.neuroph.core.Layer;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.Neuron;
import org.neuroph.core.input.WeightedSum;
import org.neuroph.nnet.comp.layer.FullConnectedLayer;
import org.neuroph.nnet.comp.layer.InputLayer;
import org.neuroph.nnet.comp.neuron.InputNeuron;
/**
 * Demonstrates a small autoencoder in which a "prompt" vector is concatenated with the
 * encoder's latent output before it is fed into the middle layer, so the prompt can
 * co-determine the reconstruction (example prompts: grayscale and sepia).
 *
 * <p>NOTE(review): several calls used here ({@code FullConnectedLayer},
 * {@code getNeuronsOutput}, {@code setLayerInput}, {@code learn(double[])}) do not match the
 * Neuroph API as documented ({@code NeuralNetwork.learn} normally takes a {@code DataSet});
 * confirm they exist in the Neuroph build this sample targets.
 */
public class AutoencoderWithMiddlePrompt {

    /** Length of the data vector that is encoded and reconstructed. */
    private static final int INPUT_SIZE = 4;

    /** Length of the prompt vector; the prompts used below are one-hot (grayscale, sepia). */
    private static final int PROMPT_SIZE = 2;

    /** Length of the encoder (latent) output. */
    private static final int LATENT_SIZE = 3;

    /** Number of passes over the training set. */
    private static final int EPOCHS = 1000;

    public static void main(String[] args) {
        InputLayer inputLayer = new InputLayer(INPUT_SIZE);
        InputLayer promptLayer = new InputLayer(PROMPT_SIZE);
        FullConnectedLayer encoderLayer = new FullConnectedLayer(LATENT_SIZE);
        // Sized to the latent-followed-by-prompt vector that forwardWithPrompt injects into it.
        FullConnectedLayer middleLayer = new FullConnectedLayer(LATENT_SIZE + PROMPT_SIZE);
        FullConnectedLayer decoderLayer = new FullConnectedLayer(INPUT_SIZE);

        connectLayers(inputLayer, encoderLayer);
        connectLayers(encoderLayer, middleLayer);
        connectLayers(promptLayer, middleLayer);
        connectLayers(middleLayer, decoderLayer);

        NeuralNetwork autoencoder = new NeuralNetwork();
        autoencoder.addLayer(inputLayer);
        autoencoder.addLayer(promptLayer);
        autoencoder.addLayer(encoderLayer);
        autoencoder.addLayer(middleLayer);
        autoencoder.addLayer(decoderLayer);
        // Only the data layer is registered as network input; the prompt is supplied
        // separately, by injecting it into the middle layer in forwardWithPrompt.
        autoencoder.setInputNeurons(inputLayer.getNeurons());
        autoencoder.setOutputNeurons(decoderLayer.getNeurons());

        double[][] trainingData = {
            {0.0, 1.0, 0.0, 1.0},
            {1.0, 0.0, 1.0, 0.0}
        };
        // promptData[i] is the prompt paired with trainingData[i].
        double[][] promptData = {
            {1.0, 0.0},
            {0.0, 1.0}
        };

        for (int epoch = 0; epoch < EPOCHS; epoch++) {
            for (int i = 0; i < trainingData.length; i++) {
                double[] input = trainingData[i];
                forwardWithPrompt(autoencoder, encoderLayer, middleLayer, input, promptData[i]);
                // NOTE(review): learn() receives only the input, independent of the prompt, so
                // nothing here pushes different prompts toward different outputs — confirm intent.
                autoencoder.learn(input);
            }
        }

        double[] testInput = {1.0, 0.0, 1.0, 0.0};
        double[] grayscalePrompt = {1.0, 0.0};
        double[] sepiaPrompt = {0.0, 1.0};
        System.out.println("Testing with Grayscale Prompt:");
        testWithPrompt(autoencoder, encoderLayer, middleLayer, testInput, grayscalePrompt);
        System.out.println("Testing with Sepia Prompt:");
        testWithPrompt(autoencoder, encoderLayer, middleLayer, testInput, sepiaPrompt);
    }

    /** Adds an input connection from every neuron of {@code from} to every neuron of {@code to}. */
    private static void connectLayers(Layer from, Layer to) {
        for (Neuron fromNeuron : from.getNeurons()) {
            for (Neuron toNeuron : to.getNeurons()) {
                toNeuron.addInputConnection(fromNeuron);
            }
        }
    }

    /** Returns a new array containing {@code vector1} followed by {@code vector2}. */
    private static double[] combineVectors(double[] vector1, double[] vector2) {
        double[] combined = Arrays.copyOf(vector1, vector1.length + vector2.length);
        System.arraycopy(vector2, 0, combined, vector1.length, vector2.length);
        return combined;
    }

    /**
     * Runs the prompt-conditioned forward pass shared by training and inference: a plain pass
     * over {@code input}, then a second pass after the encoder output concatenated with
     * {@code prompt} has been injected as the middle layer's input.
     *
     * @param network      the assembled autoencoder
     * @param encoderLayer the layer whose output is used as the latent vector
     * @param middleLayer  the layer that receives latent ++ prompt
     * @param input        data vector fed to the network input
     * @param prompt       conditioning vector appended to the latent vector
     */
    private static void forwardWithPrompt(NeuralNetwork network, Layer encoderLayer,
            Layer middleLayer, double[] input, double[] prompt) {
        network.setInput(input);
        network.calculate();
        double[] latent = encoderLayer.getNeuronsOutput();
        network.setLayerInput(middleLayer, combineVectors(latent, prompt));
        network.calculate();
    }

    /**
     * Runs the same prompt-conditioned pass used during training and prints the network output.
     * The encoder and middle layers are passed by reference, not by position, so inference
     * always reads and writes the same layers as training.
     */
    private static void testWithPrompt(NeuralNetwork autoencoder, Layer encoderLayer,
            Layer middleLayer, double[] input, double[] prompt) {
        forwardWithPrompt(autoencoder, encoderLayer, middleLayer, input, prompt);
        System.out.println("Reconstructed Output: " + Arrays.toString(autoencoder.getOutput()));
    }
}
ความคิดเห็น
แสดงความคิดเห็น