InMemoryClassLoader.java
/*-
* #%L
* io.earcam.instrumental.lade
* %%
* Copyright (C) 2018 earcam
* %%
* SPDX-License-Identifier: (BSD-3-Clause OR EPL-1.0 OR Apache-2.0 OR MIT)
*
* You <b>must</b> choose to accept, in full - any individual or combination of
* the following licenses:
* <ul>
* <li><a href="https://opensource.org/licenses/BSD-3-Clause">BSD-3-Clause</a></li>
* <li><a href="https://www.eclipse.org/legal/epl-v10.html">EPL-1.0</a></li>
* <li><a href="https://www.apache.org/licenses/LICENSE-2.0">Apache-2.0</a></li>
* <li><a href="https://opensource.org/licenses/MIT">MIT</a></li>
* </ul>
* #L%
*/
package io.earcam.instrumental.lade;
import static java.util.Arrays.asList;
import static java.util.Collections.emptyEnumeration;
import static java.util.Collections.emptyMap;
import static java.util.Collections.singleton;
import static java.util.Collections.synchronizedMap;
import static java.util.Locale.ROOT;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.ref.WeakReference;
import java.net.URL;
import java.security.CodeSigner;
import java.security.CodeSource;
import java.security.SecureClassLoader;
import java.security.cert.X509Certificate;
import java.util.Collection;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.jar.JarEntry;
import java.util.jar.JarInputStream;
import java.util.jar.JarOutputStream;
import java.util.jar.Manifest;
import java.util.stream.Stream;
import javax.annotation.concurrent.ThreadSafe;
import io.earcam.unexceptional.Exceptional;
import io.earcam.utilitarian.security.Signatures;
/**
 * <p>
 * A {@link java.lang.ClassLoader} that doesn't require a file system.
 * </p>
 *
 * <p>
 * Class/resource loading strategy is self-first. With multiple JARs added,
 * the search is linear in-order; so "hiding" is possible as per classpath.
 * The owners of each resource are kept in insertion order, so the JAR that
 * was added first is the one that is found first.
 * </p>
 *
 * <p>
 * Current partial support for signed JARs; the certificate will be loaded
 * and {@link CodeSigner}/{@link CodeSource} associated with loaded classes,
 * however no validation is performed. RSA only.
 * </p>
 */
@ThreadSafe
public final class InMemoryClassLoader extends SecureClassLoader implements AutoCloseable {

    static {
        InMemoryClassLoader.registerAsParallelCapable();
    }

    static final String MANIFEST_PATH = "META-INF/MANIFEST.MF";

    // Live loaders keyed by identityHashCodeHex(loader); used to resolve a URL host back to its loader
    static final ConcurrentMap<String, WeakReference<InMemoryClassLoader>> loaders = new ConcurrentHashMap<>();

    /**
     * The URL protocol used by resources from this {@link ClassLoader}.
     * Note: this is only valid for URLs returned by this instance as they
     * include a handler which maintains a reference to this instance.
     *
     * You may register the protocol handler to load URLs within the same JVM.
     *
     * @see Handler#addProtocolHandlerSystemProperty()
     */
    public static final String URL_PROTOCOL = "lade";

    private static final int BUFFER_SIZE = 8192;

    // Map of resource-name to (insertion-ordered, synchronized) map of resource-owner-name to resource
    private final ConcurrentMap<String, Map<String, byte[]>> resources = new ConcurrentHashMap<>();
    private final ConcurrentMap<String, CodeSource> codeSources = new ConcurrentHashMap<>();
    private final ConcurrentMap<String, Class<?>> loaded = new ConcurrentHashMap<>();
    private final Handler handler = new Handler();

    // volatile: may be flipped by one thread while another thread is adding JARs
    private volatile boolean addCodeSources = true;


    /**
     * Exposed for use as the "{@code java.system.class.loader}".
     * The new instance registers itself (weakly referenced) in the static
     * loader registry so that URLs it hands out can be resolved later.
     *
     * @param parent a {@link java.lang.ClassLoader} instance.
     * @see ClassLoader#getSystemClassLoader()
     */
    public InMemoryClassLoader(ClassLoader parent)
    {
        super(parent);
        // Drop registry entries whose loader has been garbage collected without being closed
        loaders.values().removeIf(reference -> reference.get() == null);
        loaders.put(identityHashCodeHex(this), new WeakReference<InMemoryClassLoader>(this));
    }


    /**
     * <p>
     * By default the {@link InMemoryClassLoader} adds the {@link CodeSigner}/{@link CodeSource}
     * details where found.
     * </p>
     * <p>
     * Invoking this method disables this behaviour, but only for JARs added subsequently.
     * </p>
     *
     * @return this loader, for chaining.
     */
    public InMemoryClassLoader doNotAddSubsequentSignatures()
    {
        this.addCodeSources = false;
        return this;
    }


    /**
     * Adds each of the given JARs, in array order.
     *
     * @param jars the contents of zero or more JARs, one byte array per JAR.
     * @return this loader, for chaining.
     */
    public InMemoryClassLoader jars(byte[]... jars)
    {
        return jars(asList(jars));
    }


    /**
     * Adds each of the given JARs, in the collection's iteration order.
     *
     * @param jars the contents of zero or more JARs, one byte array per JAR.
     * @return this loader, for chaining.
     */
    public InMemoryClassLoader jars(Collection<byte[]> jars)
    {
        return jars(jars.stream());
    }


    /**
     * Adds each of the given JARs, in the stream's encounter order. Each JAR
     * is named after the identity hash code (upper-case hex) of its byte array.
     *
     * @param jars the contents of zero or more JARs, one byte array per JAR.
     * @return this loader, for chaining.
     */
    public InMemoryClassLoader jars(Stream<byte[]> jars)
    {
        jars.forEach(j -> jar(j, identityHashCodeHex(j)));
        return this;
    }


    private static String identityHashCodeHex(Object object)
    {
        return Integer.toString(System.identityHashCode(object), 16).toUpperCase(ROOT);
    }


    /**
     * Adds a single JAR, named after the identity hash code of the byte array.
     *
     * @param jar the byte array representing a jar's contents.
     * @return this loader, for chaining.
     */
    public InMemoryClassLoader jar(byte[] jar)
    {
        return jars(singleton(jar));
    }


    /**
     * Adds a single, named JAR.
     *
     * @param jar adds the byte array representing a jar's contents
     * @param name the "name" to use for this jar (for URL resource references etc)
     * @return this loader, for chaining.
     */
    public InMemoryClassLoader jar(byte[] jar, String name)
    {
        jar(new ByteArrayInputStream(jar), name);
        return this;
    }


    /**
     * Adds a single, named JAR read from a stream. The stream is wrapped in a
     * {@link JarInputStream}, which is closed once the JAR has been read.
     *
     * @param jar adds the {@link java.io.InputStream} representing a jar's contents
     * @param name the "name" to use for this jar (for URL resource references etc)
     * @return this loader, for chaining.
     */
    public InMemoryClassLoader jar(InputStream jar, String name)
    {
        Exceptional.run(() -> jar(new JarInputStream(jar), name));
        return this;
    }


    private InMemoryClassLoader jar(JarInputStream jar, String name) throws IOException
    {
        // Closing here as well guarantees the stream is released even if manifest capture fails
        try(JarInputStream input = jar) {
            addManifest(input, name);
            addJarEntries(input, name);
        }
        if(addCodeSources) {
            addCodeSource(name);
        }
        return this;
    }


    /**
     * Stores the manifest exposed by {@link JarInputStream#getManifest()} (if any)
     * as the {@value #MANIFEST_PATH} resource of the named JAR.
     *
     * @param jar the JAR being added.
     * @param name the name of the JAR being added.
     * @throws IOException if the manifest cannot be serialized.
     */
    void addManifest(JarInputStream jar, String name) throws IOException
    {
        Manifest manifest = jar.getManifest();
        if(manifest != null) {
            ByteArrayOutputStream os = new ByteArrayOutputStream();
            manifest.write(os);
            ownersOf(MANIFEST_PATH).put(name, os.toByteArray());
        }
    }


    /**
     * Reads every remaining entry of the JAR into memory, keyed by entry name
     * and owned by the named JAR. The stream is closed afterwards.
     *
     * @param jar the JAR being added.
     * @param name the name of the JAR being added.
     * @throws IOException if the JAR cannot be read.
     */
    void addJarEntries(JarInputStream jar, String name) throws IOException
    {
        try(JarInputStream wrap = jar) {
            JarEntry jarEntry;
            while((jarEntry = wrap.getNextJarEntry()) != null) {
                ownersOf(jarEntry.getName()).put(name, inputStreamToBytes(wrap));
            }
        }
    }


    // Returns (creating when absent) the insertion-ordered owner map for a resource name
    private Map<String, byte[]> ownersOf(String resourceName)
    {
        return resources.computeIfAbsent(resourceName, k -> synchronizedMap(new LinkedHashMap<String, byte[]>()));
    }


    private void addCodeSource(String name)
    {
        // For every signature file (*.SF) look for the sibling *.RSA block owned by this JAR
        byte[] bytes = resources.keySet().stream()
                .filter(r -> r.startsWith("META-INF/") && r.endsWith(".SF"))
                .map(r -> r.substring(0, r.length() - 2) + "RSA")
                .map(this::resourcesForName)
                .map(m -> m.get(name))
                .filter(Objects::nonNull)
                .findFirst().orElse(new byte[0]);

        if(bytes.length != 0) {
            X509Certificate[] certificates = Signatures.certificatesFromSignature(bytes).stream().toArray(X509Certificate[]::new);
            codeSources.put(name, new CodeSource(createResourceUrl(name, ""), certificates));
        }
    }


    /**
     * Drains the stream into a byte array. The stream is not closed.
     *
     * @param input the stream to read until it reports end-of-stream.
     * @return all bytes read.
     * @throws IOException if reading fails.
     */
    static byte[] inputStreamToBytes(InputStream input) throws IOException
    {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        byte[] buffer = new byte[BUFFER_SIZE];
        int read;
        while((read = input.read(buffer)) != -1) {
            baos.write(buffer, 0, read);
        }
        return baos.toByteArray();
    }


    /**
     * Self-first class loading: a class found in the added JARs is defined by
     * this loader (first JAR added wins), otherwise the request is delegated
     * to the parent, or to the loader built into the virtual machine when
     * there is no parent.
     *
     * @param name the binary name of the class.
     * @return the loaded class.
     * @throws ClassNotFoundException if neither this loader nor its parent can find the class.
     */
    @Override
    public Class<?> loadClass(String name) throws ClassNotFoundException
    {
        Class<?> wasLoaded = loaded.get(name);
        if(wasLoaded != null) {
            return wasLoaded;
        }
        Iterator<Entry<String, byte[]>> iterator = resourcesForName(classToResourceName(name)).entrySet().iterator();
        if(!iterator.hasNext()) {
            return loadFromParent(name);
        }
        Entry<String, byte[]> resource = iterator.next();

        // This loader is parallel capable: without the lock two threads could both
        // attempt to define the same class, and the second would fail with a LinkageError
        synchronized(getClassLoadingLock(name)) {
            Class<?> defined = loaded.get(name);
            if(defined == null) {
                // Covers classes defined via define(...) or defined before close() emptied the cache
                defined = findLoadedClass(name);
            }
            if(defined == null) {
                byte[] bytes = resource.getValue();
                defined = defineClass(name, bytes, 0, bytes.length, codeSources.get(resource.getKey()));
            }
            loaded.put(name, defined);
            return defined;
        }
    }


    private Class<?> loadFromParent(String name) throws ClassNotFoundException
    {
        ClassLoader parent = getParent();
        if(parent == null) {
            // A null parent means "bootstrap"; the inherited implementation knows how to ask it
            return super.loadClass(name);
        }
        // Go through the parent's public entry point so that a parent overriding
        // only loadClass(String) (as this class does) is still honoured
        return parent.loadClass(name);
    }


    private static String classToResourceName(String name)
    {
        return name.replace('.', '/') + ".class";
    }


    // Returns an insertion-ordered snapshot of the owners of the resource; empty when unknown
    private Map<String, byte[]> resourcesForName(String resource)
    {
        Map<String, byte[]> owners = resources.get(stripLeadingSlash(resource));
        if(owners == null) {
            return emptyMap();
        }
        return snapshot(owners);
    }


    private static Map<String, byte[]> snapshot(Map<String, byte[]> owners)
    {
        // Iterating (copying) a synchronizedMap requires holding its monitor
        synchronized(owners) {
            return new LinkedHashMap<>(owners);
        }
    }


    /**
     * Releases all resources, code sources and cached classes held by this
     * loader and removes it from the static loader registry.
     */
    @Override
    public void close()
    {
        resources.clear();
        codeSources.clear();
        loaded.clear();
        loaders.remove(identityHashCodeHex(this));
    }


    @Override
    public InputStream getResourceAsStream(String name)
    {
        Iterator<byte[]> iterator = resourcesForName(name).values().iterator();
        if(!iterator.hasNext()) {
            return super.getResourceAsStream(name);
        }
        return new ByteArrayInputStream(iterator.next());
    }


    private static String stripLeadingSlash(String name)
    {
        return (name.length() > 0 && name.charAt(0) == '/') ? name.substring(1) : name;
    }


    /**
     * Resolves a URL previously created by an {@link InMemoryClassLoader}. The
     * host has the form {@code <loader-id>.<archive-id>}; an empty path asks
     * for the whole archive, which is re-serialized as a JAR.
     *
     * @param earl the URL to resolve.
     * @return the resource content, or {@code null} if the archive does not own such a resource.
     * @throws IOException if the URL host is malformed, the loader is no longer
     * available, or the archive cannot be serialized.
     */
    static InputStream getResourceAsStream(URL earl) throws IOException
    {
        String host = earl.getHost();
        int index = host.indexOf('.');
        if(index < 0) {
            throw new IOException("Malformed host, expected <loader-id>.<archive-id>: " + earl);
        }
        String loaderId = host.substring(0, index);
        String archiveId = host.substring(index + 1);

        WeakReference<InMemoryClassLoader> reference = InMemoryClassLoader.loaders.get(loaderId);
        InMemoryClassLoader loader = (reference == null) ? null : reference.get();
        if(loader == null) {
            throw new IOException("No live InMemoryClassLoader registered for: " + earl);
        }

        String resource = stripLeadingSlash(earl.getPath());
        if(resource.isEmpty()) {
            // We must now serialize the entire JAR for the given archiveId ... as something
            // is trying to look up the JAR having trimmed off the resource path
            return loader.serializeJar(archiveId);
        }
        byte[] bytes = loader.resourcesForName(resource).get(archiveId);
        return (bytes == null) ? null : new ByteArrayInputStream(bytes);
    }


    private InputStream serializeJar(String archiveId) throws IOException
    {
        // Not every archive has a manifest (see addManifest), so this may be null
        byte[] manifestBytes = resourcesForName(MANIFEST_PATH).get(archiveId);
        boolean found = manifestBytes != null;

        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        try(JarOutputStream output = openJarOutput(baos, manifestBytes)) {
            for(Entry<String, Map<String, byte[]>> entry : resources.entrySet()) {
                byte[] bytes = entry.getValue().get(archiveId);
                // The manifest (when present) has already been written by the JarOutputStream constructor
                if(bytes == null || MANIFEST_PATH.equals(entry.getKey())) {
                    continue;
                }
                JarEntry jarEntry = new JarEntry(entry.getKey());
                jarEntry.setSize(bytes.length);
                output.putNextEntry(jarEntry);
                output.write(bytes);
                found = true;
            }
        }
        if(!found) {
            throw new IOException("Unknown archive: " + archiveId);
        }
        return new ByteArrayInputStream(baos.toByteArray());
    }


    private static JarOutputStream openJarOutput(ByteArrayOutputStream sink, byte[] manifestBytes) throws IOException
    {
        if(manifestBytes == null) {
            return new JarOutputStream(sink);
        }
        return new JarOutputStream(sink, new Manifest(new ByteArrayInputStream(manifestBytes)));
    }


    @Override
    public URL getResource(String name)
    {
        Enumeration<URL> earls = getResources(name);
        if(earls.hasMoreElements()) {
            return earls.nextElement();
        }
        return super.getResource(name);
    }


    @Override
    public Enumeration<URL> getResources(String name)
    {
        return getResources(name, true);
    }


    /**
     * Finds all resources with the given name; those owned by this loader
     * come first (in the order their JARs were added).
     *
     * @param name the resource name.
     * @param includeParents when {@code true}, the results of the inherited
     * {@code getResources(String)} are appended after this loader's own.
     * @return an enumeration of URLs for the matching resources.
     */
    public Enumeration<URL> getResources(String name, boolean includeParents)
    {
        Iterator<Entry<String, byte[]>> iterator = resourcesForName(name).entrySet().iterator();
        Enumeration<URL> superResources = includeParents ? Exceptional.apply(super::getResources, name) : emptyEnumeration();

        return new Enumeration<URL>() {

            @Override
            public boolean hasMoreElements()
            {
                return iterator.hasNext() || superResources.hasMoreElements();
            }


            @Override
            public URL nextElement()
            {
                if(iterator.hasNext()) {
                    String archiveId = iterator.next().getKey();
                    return createResourceUrl(archiveId, name);
                }
                return superResources.nextElement();
            }
        };
    }


    /**
     * Streams a URL for every resource of every JAR held by this loader.
     *
     * @return a stream of resource URLs using the {@value #URL_PROTOCOL} protocol.
     */
    public Stream<URL> resources()
    {
        return resources.entrySet()
                .stream()
                .flatMap(e -> snapshot(e.getValue()).keySet().stream().map(archiveId -> createResourceUrl(archiveId, e.getKey())));
    }


    private URL createResourceUrl(String archiveId, String path)
    {
        String host = identityHashCodeHex(this) + '.' + archiveId;
        return Exceptional.url(URL_PROTOCOL, host, 0, "/" + path, handler);
    }


    /**
     * Defines a class directly from the whole of the given byte array.
     *
     * @param name the expected binary name of the class.
     * @param bytes the class file bytes.
     * @param codeSource the {@link CodeSource} to associate, may be {@code null}.
     * @return the newly defined class.
     */
    public Class<?> define(String name, byte[] bytes, CodeSource codeSource)
    {
        return define(name, bytes, 0, bytes.length, codeSource);
    }


    /**
     * Defines a class directly from a region of the given byte array.
     *
     * @param name the expected binary name of the class.
     * @param bytes the array containing the class file bytes.
     * @param offset the start of the class data within {@code bytes}.
     * @param length the length of the class data.
     * @param codeSource the {@link CodeSource} to associate, may be {@code null}.
     * @return the newly defined class.
     */
    public Class<?> define(String name, byte[] bytes, int offset, int length, CodeSource codeSource)
    {
        return defineClass(name, bytes, offset, length, codeSource);
    }
}