个人随笔
目录
纯手写实现Spring IOC
2020-09-11 22:20:45

上一篇博文,我们借助Spring的AOP实现了自定义事务注解,这里呢,我们来实现一个简单的IOC依赖注入功能,一种是通过配置文件,一种是通过注解方式,其实不要觉得Spring IOC有多高大上,其实大家学过反射都知道,当然也是通过反射来实现的。下面开始吧,我只说下原理和贴下代码哦,因为我觉得最重要的就是知道原理,靠死记代码是没有意义的,知道原理才是王道,我觉得Spring只是在这个原理的实现基础上加了设计模式和一些技巧。接下来开始吧。

一、实现原理

都是通过java的类路径反射机制来实例化类放入一个全局变量中(map)。通过先实例化bean,后处理属性的依赖来解决属性循环依赖的问题。

二、实现步骤

配置文件方式

  1. 先将配置文件的bean信息读取出来到一个list中去。list中是实例化beanid,类路径,属性集合。
  2. 循环处理第一步获取的list,根据java的反射机制调用类默认的构造方法实例化类,将实例化后的类存放在一个全局map中,keyid,值是对象。
  3. 循环第一步的list,处理属性依赖,因为实例都已经初始化了,所以不会有循环依赖锁死的问题。

注解方式

  1. 先定义两个注解,一个是Bean,一个是Source,Bean的作用是表明这个类会在容器启动的时候实例化,Source的作用是表明这个属性会在容器启动阶段初始化。
  2. 根据用户传过来的包名,通过遍历的方式去获取包及子包下面的所有类对象。
  3. 遍历第二步获取的类对象,判断类上是否有@Bean注解,若是有则加入到一个集合中去。
  4. 遍历第3步获取的集合,根据反射技术去实例化类,放入一个全局对象map中。key为类名称首字母小写。
  5. 遍历map,处理对象的属性依赖,通过获取类的字节码以及通过类的字节码获取类的所有属性值。
  6. 遍历属性,判断属性上是否有@Source注解,有的话,就根据属性名去map中获取属性对应的实例。
  7. 设置属性的修改权限为true,这样的话就算是private都可以设置。
  8. 调用Field设置属性的值。

三、示例代码

1、项目结构

文件 作用
pom.xml 项目的基本依赖
spring.xml 配置文件方式实现的所需配置
AppSpring.java 启动测试类
ClassUtil.java 根据路径,获取该路径下所有Class的工具类
ConfigInfo.java 注解方式实现的配置文件信息类,spring.xml中的信息就是读到这
MySpringCore.java 核心类,在这里实现两种初始化方法,配置文件和注解,类似于ClassPathXmlApplicationContext.java
User.java、Dog.java 要实例化的两个bean
Bean.java 注解方式实现类的@Bean注解,加了这个注解的类自动实例化
Source.java 注解方式实现类的@Source注解,加上这个注解的属性自动注入,类似有@Autowired

2、pom.xml

  1. <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  2. <modelVersion>4.0.0</modelVersion>
  3. <groupId>cn.myforever</groupId>
  4. <artifactId>my-spring</artifactId>
  5. <version>0.0.1-SNAPSHOT</version>
  6. <dependencies>
  7. <dependency>
  8. <groupId>dom4j</groupId>
  9. <artifactId>dom4j</artifactId>
  10. <version>1.6.1</version>
  11. </dependency>
  12. <dependency>
  13. <groupId>commons-lang</groupId>
  14. <artifactId>commons-lang</artifactId>
  15. <version>2.6</version>
  16. </dependency>
  17. </dependencies>
  18. </project>

这里只加上了一些工具类的依赖,像解析xml文件的dom4j

3、spring.xml

<?xml version="1.0" encoding="UTF-8"?>  
<!-- 因为是强制要求的框架,所以这里不需要dtd和项目了schema,只能是bean,否则报错 -->
<beans>
    <bean id="user" class="cn.myforever.bean.User">
        <property ref="dog"></property>
    </bean>
    <bean id="dog" class="cn.myforever.bean.Dog">
        <property id="user" ref="user"></property>
    </bean>
</beans>

这个只是简单的模仿Spring的配置文件写了两个bean

4、Bean.java

/**
 * 加入了这个注解的将会自动初始化bean
 * @author forever
 *
 */
//在类,接口enum上声明
@Target(value= {ElementType.TYPE})
//运行期有效
@Retention(RetentionPolicy.RUNTIME)
public @interface Bean {

}

5、Source.java

/**
 * 加了这个的属性,会自动根据首字母小写注入对象
 * @author forever
 *
 */
@Target(value= {ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME)
public @interface Source {

}

6、Dog.java

@Bean
public class Dog {
    @Source
    private User user;
    public void print() {
        System.out.println("我是狗子");
    }
    public void MyUser() {
        user.print();
    }
}

在用注解方式的时候,上面的注解才会生效

7、User.java

@Bean
public class User {
    @Source
    private Dog dog;
    public void print(){
        System.out.println("我是主子");
    }
    public void myDog() {
        dog.print();
    }
}

在用注解方式的时候,上面的注解才会生效

8、ConfigInfo.java

/**
 * 配置文件信息,配置文件中的信息读取出来放到这里
 * @author forever
 *
 */
public class ConfigInfo {
    private String tagName;
    private String id;//map中的key,根据此key来找对象
    private String clazz;//类全名
    private List<String> propertys;//属性

    public String getTagName() {
        return tagName;
    }
    public void setTagName(String tagName) {
        this.tagName = tagName;
    }
    public String getId() {
        return id;
    }
    public void setId(String id) {
        this.id = id;
    }
    public String getClazz() {
        return clazz;
    }
    public void setClazz(String clazz) {
        this.clazz = clazz;
    }
    public List<String> getPropertys() {
        return propertys;
    }
    public void setPropertys(List<String> propertys) {
        this.propertys = propertys;
    }
    @Override
    public String toString() {
        return "ConfigInfo [tagName=" + tagName + ", id=" + id + ", clazz=" + clazz + ", propertys=" + propertys + "]";
    }

}

这个类主要是加载配置文件,把配置文件的信息加载到这个类中,然构成后List

9、AppSpring.java

启动类,我们这里只支持一种启动方式,要不配置文件,要不就注解。

public class AppSpring {
    public static void main(String[] args) {
        //String path = "spring.xml";
        String path = "cn.myforever.bean";
        try {
            MySpringCore app = new MySpringCore(path);
            User user = (User) app.getBean("user");
            user.myDog();
            Dog dog = (Dog) app.getBean("dog");
            dog.MyUser();
        } catch (Exception e) {
            e.printStackTrace();
        }

    }

}

10、ClassUtil.java

工具类,通过路径遍历class。

public class ClassUtil {

    /**
     * 取得某个接口下所有实现这个接口的类
     */
    public static List<Class> getAllClassByInterface(Class c) {
        List<Class> returnClassList = null;

        if (c.isInterface()) {
            // 获取当前的包名
            String packageName = c.getPackage().getName();
            // 获取当前包下以及子包下所以的类
            List<Class<?>> allClass = getClasses(packageName);
            if (allClass != null) {
                returnClassList = new ArrayList<Class>();
                for (Class classes : allClass) {
                    // 判断是否是同一个接口
                    if (c.isAssignableFrom(classes)) {
                        // 本身不加入进去
                        if (!c.equals(classes)) {
                            returnClassList.add(classes);
                        }
                    }
                }
            }
        }

        return returnClassList;
    }

    /*
     * 取得某一类所在包的所有类名 不含迭代
     */
    public static String[] getPackageAllClassName(String classLocation, String packageName) {
        // 将packageName分解
        String[] packagePathSplit = packageName.split("[.]");
        String realClassLocation = classLocation;
        int packageLength = packagePathSplit.length;
        for (int i = 0; i < packageLength; i++) {
            realClassLocation = realClassLocation + File.separator + packagePathSplit[i];
        }
        File packeageDir = new File(realClassLocation);
        if (packeageDir.isDirectory()) {
            String[] allClassName = packeageDir.list();
            return allClassName;
        }
        return null;
    }

    /**
     * 从包package中获取所有的Class
     * 
     * @param pack
     * @return
     */
    public static List<Class<?>> getClasses(String packageName) {

        // 第一个class类的集合
        List<Class<?>> classes = new ArrayList<Class<?>>();
        // 是否循环迭代
        boolean recursive = true;
        // 获取包的名字 并进行替换
        String packageDirName = packageName.replace('.', '/');
        // 定义一个枚举的集合 并进行循环来处理这个目录下的things
        Enumeration<URL> dirs;
        try {
            dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName);
            // 循环迭代下去
            while (dirs.hasMoreElements()) {
                // 获取下一个元素
                URL url = dirs.nextElement();
                // 得到协议的名称
                String protocol = url.getProtocol();
                // 如果是以文件的形式保存在服务器上
                if ("file".equals(protocol)) {
                    // 获取包的物理路径
                    String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
                    // 以文件的方式扫描整个包下的文件 并添加到集合中
                    findAndAddClassesInPackageByFile(packageName, filePath, recursive, classes);
                } else if ("jar".equals(protocol)) {
                    // 如果是jar包文件
                    // 定义一个JarFile
                    JarFile jar;
                    try {
                        // 获取jar
                        jar = ((JarURLConnection) url.openConnection()).getJarFile();
                        // 从此jar包 得到一个枚举类
                        Enumeration<JarEntry> entries = jar.entries();
                        // 同样的进行循环迭代
                        while (entries.hasMoreElements()) {
                            // 获取jar里的一个实体 可以是目录 和一些jar包里的其他文件 如META-INF等文件
                            JarEntry entry = entries.nextElement();
                            String name = entry.getName();
                            // 如果是以/开头的
                            if (name.charAt(0) == '/') {
                                // 获取后面的字符串
                                name = name.substring(1);
                            }
                            // 如果前半部分和定义的包名相同
                            if (name.startsWith(packageDirName)) {
                                int idx = name.lastIndexOf('/');
                                // 如果以"/"结尾 是一个包
                                if (idx != -1) {
                                    // 获取包名 把"/"替换成"."
                                    packageName = name.substring(0, idx).replace('/', '.');
                                }
                                // 如果可以迭代下去 并且是一个包
                                if ((idx != -1) || recursive) {
                                    // 如果是一个.class文件 而且不是目录
                                    if (name.endsWith(".class") && !entry.isDirectory()) {
                                        // 去掉后面的".class" 获取真正的类名
                                        String className = name.substring(packageName.length() + 1, name.length() - 6);
                                        try {
                                            // 添加到classes
                                            classes.add(Class.forName(packageName + '.' + className));
                                        } catch (ClassNotFoundException e) {
                                            e.printStackTrace();
                                        }
                                    }
                                }
                            }
                        }
                    } catch (IOException e) {
                        e.printStackTrace();
                    }
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }

        return classes;
    }

    /**
     * 以文件的形式来获取包下的所有Class
     * 
     * @param packageName
     * @param packagePath
     * @param recursive
     * @param classes
     */
    public static void findAndAddClassesInPackageByFile(String packageName, String packagePath, final boolean recursive,
            List<Class<?>> classes) {
        // 获取此包的目录 建立一个File
        File dir = new File(packagePath);
        // 如果不存在或者 也不是目录就直接返回
        if (!dir.exists() || !dir.isDirectory()) {
            return;
        }
        // 如果存在 就获取包下的所有文件 包括目录
        File[] dirfiles = dir.listFiles(new FileFilter() {
            // 自定义过滤规则 如果可以循环(包含子目录) 或则是以.class结尾的文件(编译好的java类文件)
            public boolean accept(File file) {
                return (recursive && file.isDirectory()) || (file.getName().endsWith(".class"));
            }
        });
        // 循环所有文件
        for (File file : dirfiles) {
            // 如果是目录 则继续扫描
            if (file.isDirectory()) {
                findAndAddClassesInPackageByFile(packageName + "." + file.getName(), file.getAbsolutePath(), recursive,
                        classes);
            } else {
                // 如果是java类文件 去掉后面的.class 只留下类名
                String className = file.getName().substring(0, file.getName().length() - 6);
                try {
                    // 添加到集合中去
                    classes.add(Class.forName(packageName + '.' + className));
                } catch (ClassNotFoundException e) {
                    e.printStackTrace();
                }
            }
        }
    }
}

当然是网上找的啦。

11、MySpringCore.java

该类是核心类,这里实现了两种方式,通过构造方法传入的值判断是用配置文件的方式启动,还是用注解的方式启动。都有注释,相信认真看下再结合一开始说的原理步骤还是可以看得懂的,实在不行就自己跑一下。

/**
 * 自定义简单的spring IOC和DI框架,支持xml和注解配置方式
 * 在容易启动的时候就会实例化bean,而不是加载的时候实例化,默认属性都是首字母小写
 * @author forever
 */
public class MySpringCore {
    private ConcurrentHashMap<String, Object> map =new ConcurrentHashMap<String, Object>();
    //配置文件信息,启动应该是单线程的,不会有并发问题
    private List<ConfigInfo> configInfos = new ArrayList<ConfigInfo>(); 
    /**
     * 初始化spring容器,若是path以.xml结尾,则表明是采用配置文件方式,否则,就表明是采用注解方式
     * 传进来的就是一个包路径,配置文件方式默认读取classpath:spring.xml
     * @param path
     * @throws Exception 
     */
    public MySpringCore(String path) throws Exception {
        //1、判断是配置文件方式还是注解的方式
        if(path.contains(".xml")) {
            System.out.println("配置文件方式启动");
            initBeanByXml(path);
        }else {
            System.out.println("非配置文件方式启动,那么应该是注解的方式启动");
            //这里就要进行扫包,然后获取该包下的所有class
            initBeanByPacakage(path);
        }
    }
    /**
     * 获取bean
     * @param beanName
     * @return
     */
    public Object getBean(String beanName) {
        return map.get(beanName);
    }
    /**
     * 初始化bean
     * @param path xml文件路径
     * @throws Exception 
     */
    private void initBeanByXml(String path) throws Exception {
        //1、去获取配置文件的所有bean节点信息,这里只获取bean
        List<Element> elements = readXml(path);
        if(elements==null) {
            return;
        }
        //这里就表明,配置文件中配置有bean,初始化bean
        initConfigByElements(elements);
        System.out.println("配置文件加载成功:"+configInfos.toString());
        //这一步是提前实例化bean,防止循环依赖
        initBeanByConfigInfos();
        System.out.println("容器初始化bean成功:"+map.toString());
        //处理依赖
        dealDependency();
        System.out.println("容器初始化bean依赖成功:"+map.toString());

    }
    /**
     * 处理依赖
     * @throws SecurityException 
     * @throws NoSuchFieldException 
     * @throws IllegalAccessException 
     * @throws IllegalArgumentException 
     */
    private void dealDependency() throws NoSuchFieldException, SecurityException, IllegalArgumentException, IllegalAccessException {
        //遍历配置文件,看看哪些是有属性的
        for (ConfigInfo configInfo : configInfos) {
            if("bean".equals(configInfo.getTagName())) {
                List<String> propertys = configInfo.getPropertys();
                Object obj = map.get(configInfo.getId());
                if(propertys!=null&&propertys.size()>0) {
                    for (String string : propertys) {
                        //获取属性对应的Object
                        String[] strs = string.split("#");
                        String id = strs[0];
                        String ref = strs[1];
                        Object object = map.get(ref);
                        //设置属性,不一定需要有get,set方法
                        Class<?> clazz  =obj.getClass();
                        //getDeclaredField是可以获取一个类的所有字段. 
                        //getField只能获取类的public 字段. 
                        //Field field = clazz.getField(id);
                        Field[] fields = clazz.getDeclaredFields();
                        for (Field field2 : fields) {
                            if(id.equals(field2.getName())) {
                                //这样就可以改动私有方法
                                field2.setAccessible(true);
                                field2.set(obj, object);
                            }
                        }
                    }
                }
            }
        }
    }
    /**
     * 更具配置文件信息初始化
     * @throws Exception 
     */
    private void initBeanByConfigInfos() throws Exception {
        for (ConfigInfo configInfo : configInfos) {
            //只处理bean标签
            if("bean".equals(configInfo.getTagName())){
                Class<?> clazz= Class.forName(configInfo.getClazz());
                if(clazz==null) {
                    throw new Exception(configInfo.getClazz()+"反射生成class失败");
                }
                Object object = clazz.newInstance();
                if(object==null) {
                    throw new Exception(configInfo.getClazz()+"实例化失败");
                }
                //将生产的Object放入map
                map.put(configInfo.getId(), object);
            }
        }
    }
    /**
     * 初始化配置文件
     * @param elements
     * @throws Exception 
     */
    @SuppressWarnings("unchecked")
    private void initConfigByElements(List<Element> elements) throws Exception {
        for (Element element : elements) {
            //这里只会初始化bean标签的元素
            if("bean".equals(element.getName())) {
                String id = element.attributeValue("id");
                String clazz = element.attributeValue("class");
                if(StringUtils.isBlank(id)||StringUtils.isBlank(clazz)) {
                    throw new Exception("bean的属性定义不规范");
                }
                ConfigInfo configInfo = new ConfigInfo();
                configInfo.setId(id);
                configInfo.setClazz(clazz);
                configInfo.setTagName(element.getName());
                //获取property属性文件,这里只会处理依赖
                List<Element> eles = element.elements();
                List<String> propertys= new ArrayList<String>();
                if(eles!=null&&eles.size()>0) {
                    for (Element ele : eles) {
                        if("property".equals((ele.getName()))) {
                            String ref = ele.attributeValue("ref");
                            String id2 = ele.attributeValue("id");
                            if(id2==null) {
                                id2=ref;
                            }
                            if(ref==null) {
                                throw new Exception("property标签必须有ref属性");
                            }
                            //这里就将属性值加入
                            propertys.add(id2+"#"+ref);
                        }
                    }
                }
                configInfo.setPropertys(propertys);
                configInfos.add(configInfo);
            }
        }
    }
    /**
     * 用dom4j解析配置文件
     * @param path
     * @return
     * @throws Exception 
     */
    private List<Element> readXml(String path) throws Exception {
        SAXReader saxReader  = new SAXReader();
        if(StringUtils.isBlank(path)) {
            throw new Exception("配置文件路径不能为空");
        }
        //构造Document对象
        Document doc = saxReader.read(getInputStreamFromPath(path));
        //获取根节点信息
        Element element=doc.getRootElement();
        //判断是否是beans
        String rootName = element.getName();
        if(!"beans".equals(rootName)) {
            throw new Exception("xml文件格式不对,根节点必须是beans");
        }
        //获取所有的子节点,子节点必须是bean,如果是用来实例化的话
        @SuppressWarnings("unchecked")
        List<Element> elements = element.elements();
        if(elements==null||elements.isEmpty()) {
            return null;
        }
        return elements;
    }
    /**
     * 默认去classpath下面找寻配置文件
     * @param path
     * @return
     */
    private InputStream getInputStreamFromPath(String path) {
        InputStream is = getClass().getClassLoader().getResourceAsStream(path);
        return is;
    }
    public MySpringCore() throws Exception {
        this("spring.xml");
    }
    //-------------------------------------------------------//
    //下面是以注解的方式启动
    private void initBeanByPacakage(String path) throws InstantiationException, IllegalAccessException {
        //用工具类扫包获取包及子包下面的所有类
        List<Class<?>> list= ClassUtil.getClasses(path);
        //若是没有类,则不处理
        if(list==null||list.size()<1) {
            return;
        }
        //获取所有加了@Bean的类
        List<Class<?>> haveBeanClass = findHaveBeanAnnotationClass(list);
        //初始化bean对象
        initBeanByClasses(haveBeanClass);
        //初始化依赖问题
        dealDependencyBySource();
        //初始化成功
        System.out.println("容器初始化bean依赖成功:"+map.toString());

    }
    /**
     * 处理依赖问题
     * @throws IllegalAccessException 
     * @throws IllegalArgumentException 
     */
    private void dealDependencyBySource() throws IllegalArgumentException, IllegalAccessException {
        for(Entry<String,Object> entry :map.entrySet()) {
            dealDependencyBySource(entry.getValue());
        }

    }
    /**
     * 处理属性依赖
     * @param value
     * @throws IllegalAccessException 
     * @throws IllegalArgumentException 
     */
    private void dealDependencyBySource(Object value) throws IllegalArgumentException, IllegalAccessException {
        Class<?> clazz = value.getClass();
        Field[] fields = clazz.getDeclaredFields();
        for (Field field : fields) {
            //判断属性上是否有@Source注解
            Source source = field.getAnnotation(Source.class);
            if(source!=null) {
                String name = field.getName();
                //去map获取需要注入的bean
                Object obj = map.get(name);
                //设置属性
                field.setAccessible(true);
                field.set(value, obj);
            }
        }

    }
    /**
     * 这些都是有@Bean注解的类,所以要初始化,如果用户要把@Bean加载接口上,那么初始化会报错,用户自己处理
     * @param haveBeanClass
     * @throws IllegalAccessException 
     * @throws InstantiationException 
     */
    private void initBeanByClasses(List<Class<?>> haveBeanClass) throws InstantiationException, IllegalAccessException {
        //循环实例化
        for (Class<?> class1 : haveBeanClass) {
            //获取类的名称,然后首字母小写变成key
            String name  =class1.getSimpleName();
            name = toLowerCaseFirstOne(name);
            //实例化
            Object object = class1.newInstance();
            map.put(name, object);
        }

    }
    // 首字母转小写
    public static String toLowerCaseFirstOne(String s) {
        if (Character.isLowerCase(s.charAt(0)))
            return s;
        else
            return (new StringBuilder()).append(Character.toLowerCase(s.charAt(0))).append(s.substring(1)).toString();
    }
    /**
     * 获取所有类上有@Bean注解的类
     * @param lists
     * @return
     */
    private List<Class<?>> findHaveBeanAnnotationClass(List<Class<?>> lists) {
        List<Class<?>> list = new ArrayList<Class<?>>();
        for (Class<?> class1 : lists) {
            if(class1.isAnnotationPresent(Bean.class)) {
                list.add(class1);
            }
        }
        return list;
    }
}

四、注意事项

本质上都是通过配置文件,和注解+反射来初始化bean,不需要通过用户自己new.
这里初始化的全部都是单例模式。内存中只会有一份,并且多例模式的循环依赖不好解决.
获取属性的时候必须用Field[] fields = clazz.getDeclaredFields();这个,若是用Field field = clazz.getField(id);会获取不到private.
全局map建议使用线程安全的ConcurrentHashMap<String, Object> map 来。

这里提供git地址:https://github.com/suibibk/my-spring.git

结语

学习任何技术都不要去死记代码,不出两个月绝对会忘记,我们只需要知道原理,那么我们便可以根据原理来实现出来。框架的原理也许都差不多,只不过它的代码重构以及加上了很多设计模式,使得代码更加的通用。:bowtie:

 316

啊!这个可能是世界上最丑的留言输入框功能~


当然,也是最丑的留言列表

有疑问发邮件到 : suibibk@qq.com 侵权立删
Copyright : 个人随笔   备案号 : 粤ICP备18099399号-2